Files
ragflow/tools/scripts/db_schema_sync.py
Liu An f038a34154 Docs: Update version references to v0.25.4 in READMEs and docs (#14912)
### What problem does this PR solve?

- Update version tags in README files (including translations) from
v0.25.3 to v0.25.4
- Modify Docker image references and documentation to reflect new
version
- Update version badges and image descriptions
- Maintain consistency across all language variants of README files

### Type of change

- [x] Documentation Update
2026-05-14 11:07:08 +08:00

951 lines
35 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Database Schema Sync Script
This script synchronizes database models defined in api/db/db_models.py
with the actual database schema using peewee-migrate.
Features:
1. Reads model definitions from api/db/db_models.py
2. Compares with existing database tables specified via command line
3. Generates migration files in tools/migrate/{version}/
"""
import argparse
import importlib.util
import inspect
import logging
import os
import re
import sys
from peewee import MySQLDatabase, Model, Field
from peewee_migrate import Router
# Repository root: this script lives in <root>/tools/scripts/, two directory
# levels below it. Put the root first on sys.path so project packages import.
PROJECT_BASE = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir))
sys.path.insert(0, PROJECT_BASE)

# Script-wide console logging: timestamp, level name, message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def validate_version(version: str) -> bool:
    """Return True if *version* has the form ``v<major>.<minor>.<patch>``.

    Each component must be one or more ASCII digits, e.g. ``v0.25.4``.

    ``re.fullmatch`` is used rather than ``re.match`` with a ``$`` anchor,
    because ``$`` also matches just before a trailing newline. That would
    accept ``"v0.25.4"`` followed by a newline and leak the newline into the
    migration directory name. ``[0-9]`` is used instead of ``\\d`` so that
    non-ASCII digits are rejected.

    Args:
        version: Candidate version string.

    Returns:
        bool: True when the whole string matches the pattern; False otherwise,
        including for non-string input.
    """
    if not isinstance(version, str):
        return False
    return re.fullmatch(r'v[0-9]+\.[0-9]+\.[0-9]+', version) is not None
def version_to_dirname(version: str) -> str:
    """Turn a dotted version string into a directory-friendly name.

    Every ``.`` is replaced by ``_``; for example ``'v0.25.4'`` becomes
    ``'v0_25_4'``.
    """
    return '_'.join(version.split('.'))
def load_db_models():
    """Execute ``api/db/db_models.py`` from the project root and collect its models.

    The file is loaded as a standalone module named ``db_models``.

    Returns:
        tuple: ``(models, module)``. ``models`` lists every peewee ``Model``
        subclass exposed by the module, in ``inspect.getmembers`` order (sorted
        by attribute name). ``Model`` itself, ``BaseModel`` and
        ``DataBaseModel`` are excluded, and only classes whose ``_meta`` has a
        ``database`` attribute are kept.

    Raises:
        FileNotFoundError: If ``api/db/db_models.py`` does not exist.
    """
    models_path = os.path.join(PROJECT_BASE, 'api', 'db', 'db_models.py')
    if not os.path.exists(models_path):
        raise FileNotFoundError(f"db_models.py not found at {models_path}")

    spec = importlib.util.spec_from_file_location("db_models", models_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    excluded_names = ('BaseModel', 'DataBaseModel')
    collected = [
        candidate
        for _, candidate in inspect.getmembers(module, inspect.isclass)
        if issubclass(candidate, Model)
        and candidate is not Model
        and candidate.__name__ not in excluded_names
        and hasattr(candidate._meta, 'database')
    ]
    return collected, module
def create_database_connection(host: str, port: int, user: str, password: str, database: str):
    """Build a peewee ``MySQLDatabase`` handle from the CLI connection settings.

    The handle always uses the ``utf8mb4`` character set. It is returned
    without connecting; callers open it themselves.
    """
    connect_kwargs = {
        'host': host,
        'port': port,
        'user': user,
        'password': password,
        'charset': 'utf8mb4',
    }
    return MySQLDatabase(database, **connect_kwargs)
# MySQL ``information_schema.columns.data_type`` -> peewee base field class name.
# Values must agree with the names get_base_field_type() returns for model
# fields; any mismatch is reported as a "changed" column. In particular:
#   * 'decimal' maps to DecimalField (a peewee DecimalField resolves to
#     'DecimalField', so mapping it to FloatField flagged every decimal column);
#   * 'date' / 'time' are mapped so DateField / TimeField columns are not
#     falsely flagged.
# Lookups are keyed on data_type, which carries no length, so tinyint(1) is
# detected from column_type in get_table_columns(). The 'tinyint(1)' entry is
# kept only for callers that look up a full column type.
MYSQL_TO_PEEWEE_TYPE = {
    'varchar': 'CharField',
    'char': 'CharField',
    'tinytext': 'TextField',
    'text': 'TextField',
    'mediumtext': 'TextField',
    'longtext': 'TextField',
    'json': 'TextField',
    'int': 'IntegerField',
    'integer': 'IntegerField',
    'bigint': 'BigIntegerField',
    'float': 'FloatField',
    'double': 'FloatField',
    'decimal': 'DecimalField',
    'date': 'DateField',
    'time': 'TimeField',
    'datetime': 'DateTimeField',
    'timestamp': 'DateTimeField',
    'tinyint(1)': 'BooleanField',
    'tinyint': 'IntegerField',
    'smallint': 'IntegerField',
    'mediumint': 'IntegerField',
}

# Peewee base field class name -> MySQL data_type (reverse lookup).
PEEWEE_TO_MYSQL_TYPE = {
    'CharField': 'varchar',
    'TextField': 'text',
    'IntegerField': 'int',
    'BigIntegerField': 'bigint',
    'FloatField': 'float',
    'DecimalField': 'decimal',
    'BooleanField': 'tinyint',
    'DateTimeField': 'datetime',
    'DateField': 'date',
    'TimeField': 'time',
}
def get_table_columns(db, table_name: str) -> dict:
    """Read the column metadata of *table_name* from ``information_schema``.

    Args:
        db: peewee database; ``db.database`` is used as the schema name and
            ``db.execute_sql`` runs the query.
        table_name: Table to inspect.

    Returns:
        dict: ``{column_name: info}`` in ordinal order. Each ``info`` holds:
        ``data_type`` and ``column_type`` (lower-cased), ``peewee_type``,
        ``nullable``, ``default`` (raw ``column_default``), ``is_primary``,
        ``column_key`` (raw key flag such as ``'PRI'``, ``'UNI'``, ``'MUL'``)
        and ``extra``.
    """
    cursor = db.execute_sql("""
        SELECT
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """, (db.database, table_name))
    try:
        rows = cursor.fetchall()
    finally:
        # Release the DB-API cursor as soon as the rows are materialised.
        cursor.close()

    columns = {}
    for row in rows:
        col_name = row[0]
        data_type = row[1].lower()
        column_type = row[2].lower()
        column_key = row[5]

        # MySQL stores BOOL as tinyint(1); data_type alone ('tinyint') cannot
        # tell it apart from a plain tinyint.
        if column_type.startswith('tinyint(1)'):
            peewee_type = 'BooleanField'
        else:
            peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, 'TextField')

        columns[col_name] = {
            'data_type': data_type,
            'column_type': column_type,
            'peewee_type': peewee_type,
            'nullable': row[3] == 'YES',
            'default': row[4],
            'is_primary': column_key == 'PRI',
            # Kept verbatim: generate_rollback_add_field_sql() checks it
            # ('MUL') to decide whether to recreate an index.
            'column_key': column_key,
            'extra': row[6] or '',
        }
    return columns
def get_peewee_field_type(field: Field) -> str:
    """Return the concrete class name of *field* (e.g. ``'JSONField'``).

    No normalisation to a base peewee type is applied.
    """
    return type(field).__name__
def get_base_field_type(field: Field) -> str:
    """Resolve *field* to the nearest standard peewee field class name.

    Custom field types subclass a stock peewee type. This walks the field's
    class MRO from most to least specific and returns the first name in the
    standard set; for example, a subclass of ``CharField`` resolves to
    ``'CharField'``. Returns ``'TextField'`` when nothing in the hierarchy is
    recognised.
    """
    standard_names = frozenset((
        'CharField', 'TextField', 'IntegerField', 'BigIntegerField',
        'FloatField', 'BooleanField', 'DateTimeField', 'DateField',
        'TimeField', 'DecimalField', 'ForeignKeyField', 'ManyToManyField',
        'PrimaryKeyField', 'AutoField',
    ))
    return next(
        (klass.__name__ for klass in field.__class__.__mro__ if klass.__name__ in standard_names),
        'TextField',
    )
def normalize_field_type(field: Field) -> str:
    """Return the type name used when comparing a model field with a DB column.

    Currently identical to ``get_base_field_type(field)``.
    """
    base_type = get_base_field_type(field)
    return base_type
def compare_fields(model_fields: dict, db_columns: dict, skip_fields=None) -> dict:
    """Diff a model's fields against the live columns of its table.

    Args:
        model_fields: ``{field_name: peewee Field}`` for the model.
        db_columns: ``{column_name: info}`` as returned by ``get_table_columns``;
            ``info['peewee_type']`` and ``info['column_type']`` are read.
        skip_fields: Names ignored in both directions. Defaults to ``id`` plus
            the ``create_time``/``create_date``/``update_time``/``update_date``
            bookkeeping columns. The callers in this script remove those names
            from *model_fields* before calling. Skipping only ``id`` on the DB
            side therefore reported the four timestamp columns of every table
            as "removed", and ``--drop`` would have dropped them.

    Returns:
        dict: {
            'added': {field_name: field_obj},               # in model, not in DB
            'changed': {field_name: (db_col_info, field)},  # base type differs
            'removed': {column_name: db_col_info},          # in DB, not in model
        }
    """
    if skip_fields is None:
        skip_fields = {'id', 'create_time', 'create_date', 'update_time', 'update_date'}

    result = {
        'added': {},
        'changed': {},
        'removed': {},
    }

    for field_name, field in model_fields.items():
        if field_name in skip_fields:
            continue
        if field_name not in db_columns:
            result['added'][field_name] = field
            logger.info(f" New field detected: {field_name} ({field.__class__.__name__})")
            continue
        db_col = db_columns[field_name]
        model_base_type = normalize_field_type(field)
        db_type = db_col['peewee_type']
        if model_base_type != db_type:
            result['changed'][field_name] = (db_col, field)
            logger.info(f" Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})")

    for col_name, col_info in db_columns.items():
        if col_name in skip_fields or col_name in model_fields:
            continue
        result['removed'][col_name] = col_info
        logger.info(f" Removed field detected: {col_name} ({col_info['column_type']})")

    return result
def generate_field_code(field: Field, field_name: str) -> str:
    """Render a peewee field constructor call for a generated migration file.

    Project-specific field classes are emitted as the stock peewee type they
    are stored as (e.g. ``JSONField`` -> ``pw.TextField``), because the
    generated migration module only imports ``peewee as pw``.

    Fixes over the naive rendering:
      * ``primary_key=True`` is emitted, so a non-auto primary key such as a
        ``CharField`` ``id`` keeps its role in ``create_model``.
      * Literal defaults are rendered with ``repr()``, which yields a valid
        Python literal even for strings containing quotes, backslashes or
        newlines.
      * ``DecimalField`` keeps ``max_digits`` / ``decimal_places``.

    Args:
        field: Model field to render.
        field_name: Attribute name of the field (not used in the output).

    Returns:
        str: Source text such as ``pw.CharField(max_length=32, primary_key=True)``.
    """
    custom_to_standard = {
        'LongTextField': 'TextField',
        'JSONField': 'TextField',
        'ListField': 'TextField',
        'SerializedField': 'TextField',
        'DateTimeTzField': 'CharField',
    }
    field_class = field.__class__.__name__
    pw_field_class = custom_to_standard.get(field_class, field_class)

    args = []
    if pw_field_class == 'CharField' and getattr(field, 'max_length', None) is not None:
        args.append(f"max_length={field.max_length}")
    if pw_field_class == 'DecimalField':
        for attr in ('max_digits', 'decimal_places'):
            value = getattr(field, attr, None)
            if value is not None:
                args.append(f"{attr}={value!r}")
    if getattr(field, 'primary_key', False):
        args.append("primary_key=True")
    if field.null:
        args.append("null=True")

    default_val = field.default
    # Callables and other non-literal defaults cannot be written as source
    # text and are skipped.
    if default_val is not None and isinstance(default_val, (str, bool, int, float, dict, list)):
        args.append(f"default={default_val!r}")

    if getattr(field, 'index', False):
        args.append("index=True")
    if getattr(field, 'unique', False):
        args.append("unique=True")

    return f"pw.{pw_field_class}({', '.join(args)})"
def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> "tuple[str, str | None]":
    """Generate raw MySQL to add *field* as a new column of an existing table.

    Used for existing tables where ``migrator.add_fields`` doesn't work
    because the model is not registered in ``migrator.orm``.

    The column type comes from the most specific class in the field's MRO that
    has a mapping. An unmapped subclass (e.g. ``SmallIntegerField``, a
    ``FloatField`` subclass, or a project ``CharField`` subclass) therefore
    inherits its parent's MySQL type instead of silently becoming
    ``LONGTEXT``. ``LONGTEXT`` remains the fallback when nothing matches.

    Defaults on ``*TEXT`` columns are emitted as parenthesised expressions,
    ``DEFAULT ('...')``. MySQL rejects literal defaults on TEXT/BLOB/JSON
    columns; expression defaults require MySQL 8.0.13+.

    Args:
        table_name: Existing table to alter.
        field: Model field describing the new column.
        field_name: Name of the column to add.

    Returns:
        tuple: ``(alter_sql, index_sql)``. ``index_sql`` is a ``CREATE INDEX``
        statement when ``field.index`` is truthy, otherwise ``None``.
    """
    import json

    def sql_string(value: str) -> str:
        # MySQL string literal: backslash is an escape character (unless the
        # server uses NO_BACKSLASH_ESCAPES), so double it before doubling quotes.
        return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'"

    max_length = getattr(field, 'max_length', None)
    varchar = f'VARCHAR({max_length})' if max_length else 'VARCHAR(255)'
    max_digits = getattr(field, 'max_digits', None)
    decimal_places = getattr(field, 'decimal_places', None)
    decimal = (f'DECIMAL({10 if max_digits is None else max_digits},'
               f'{5 if decimal_places is None else decimal_places})')

    mysql_type_map = {
        'CharField': varchar,
        'DateTimeTzField': varchar,
        'TextField': 'LONGTEXT',
        'LongTextField': 'LONGTEXT',
        'JSONField': 'LONGTEXT',
        'ListField': 'LONGTEXT',
        'SerializedField': 'LONGTEXT',
        'SmallIntegerField': 'SMALLINT',
        'IntegerField': 'INT',
        'BigIntegerField': 'BIGINT',
        'BigAutoField': 'BIGINT',
        'FloatField': 'DOUBLE',
        'DecimalField': decimal,
        'BooleanField': 'TINYINT(1)',
        'DateTimeField': 'DATETIME',
        'DateField': 'DATE',
        'TimeField': 'TIME',
    }
    mysql_type = next(
        (mysql_type_map[klass.__name__] for klass in field.__class__.__mro__ if klass.__name__ in mysql_type_map),
        'LONGTEXT',
    )

    parts = [f'`{field_name}`', mysql_type, 'NULL' if field.null else 'NOT NULL']

    default_val = field.default
    default_sql = None
    if isinstance(default_val, str):
        default_sql = sql_string(default_val)
    elif isinstance(default_val, bool):  # before int: bool is an int subclass
        default_sql = '1' if default_val else '0'
    elif isinstance(default_val, (int, float)):
        default_sql = f'{default_val}'
    elif isinstance(default_val, (dict, list)):
        default_sql = sql_string(json.dumps(default_val))
    if default_sql is not None:
        if mysql_type.endswith('TEXT'):
            default_sql = f'({default_sql})'
        parts.append(f'DEFAULT {default_sql}')

    if getattr(field, 'help_text', None):
        parts.append(f'COMMENT {sql_string(field.help_text)}')

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    index_sql = None
    if getattr(field, 'index', False):
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"
    return sql, index_sql
def generate_drop_field_sql(table_name: str, field_name: str) -> str:
    """Return the ``ALTER TABLE ... DROP COLUMN`` statement for one column.

    Table and column names are wrapped in backticks.
    """
    return "ALTER TABLE `{}` DROP COLUMN `{}`".format(table_name, field_name)
def generate_rollback_field_sql(table_name: str, field_name: str) -> str:
    """Return SQL that undoes an ``ADD COLUMN`` by dropping the column again."""
    quoted_table = f"`{table_name}`"
    quoted_column = f"`{field_name}`"
    return " ".join(("ALTER TABLE", quoted_table, "DROP COLUMN", quoted_column))
def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: str) -> "tuple[str, str | None]":
    """Generate SQL for rolling back a dropped field (re-adding it).

    The ``ADD COLUMN`` statement is reconstructed from the column info captured
    before the field was dropped. Data in the column is not restored.

    Args:
        table_name: Table the column was dropped from.
        col_info: Column metadata. Reads ``column_type`` (default
            ``LONGTEXT``), ``nullable`` (default True), ``default`` and,
            when present, ``column_key``.
        field_name: Column name to re-add.

    Returns:
        tuple: ``(add_sql, index_sql)``. ``index_sql`` is ``CREATE INDEX`` for
        ``column_key == 'MUL'``, ``CREATE UNIQUE INDEX`` for ``'UNI'``, and
        ``None`` otherwise, including when ``column_key`` is absent.
        NOTE(review): a 'MUL' flag can also come from a composite index; only
        a single-column index is recreated here.
    """
    def sql_string(value: str) -> str:
        # MySQL string literal: double backslashes, then single quotes.
        return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'"

    parts = [
        f'`{field_name}`',
        col_info.get('column_type', 'LONGTEXT'),
        'NULL' if col_info.get('nullable', True) else 'NOT NULL',
    ]

    default_val = col_info.get('default')
    if isinstance(default_val, str):
        if default_val.upper().startswith('CURRENT_TIMESTAMP'):
            # Function default: quoting it would turn it into an invalid
            # string literal for a DATETIME/TIMESTAMP column.
            parts.append(f"DEFAULT {default_val}")
        else:
            parts.append(f"DEFAULT {sql_string(default_val)}")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    column_key = col_info.get('column_key')
    index_sql = None
    if column_key == 'MUL':
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"
    elif column_key == 'UNI':
        index_sql = f"CREATE UNIQUE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"
    return sql, index_sql
def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str:
    """Generate SQL for rolling back a field type change.

    Restores the column type, nullability and default captured before the
    change. Data values may need manual handling if the type conversion caused
    data loss or transformation.

    Args:
        table_name: Table owning the column.
        old_info: Column metadata captured before the change. Reads
            ``column_type`` (default ``LONGTEXT``), ``nullable`` (default True)
            and ``default``.
        field_name: Column to restore.

    Returns:
        str: An ``ALTER TABLE ... MODIFY COLUMN`` statement.
    """
    def sql_string(value: str) -> str:
        # MySQL string literal: double backslashes, then single quotes.
        return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'"

    parts = [
        f'`{field_name}`',
        old_info.get('column_type', 'LONGTEXT'),
        'NULL' if old_info.get('nullable', True) else 'NOT NULL',
    ]

    default_val = old_info.get('default')
    if isinstance(default_val, str):
        if default_val.upper().startswith('CURRENT_TIMESTAMP'):
            # Function default: must stay unquoted to remain a function call.
            parts.append(f"DEFAULT {default_val}")
        else:
            parts.append(f"DEFAULT {sql_string(default_val)}")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> str:
    """Generate MySQL for changing an existing column to match *field*.

    The column type comes from the most specific class in the field's MRO that
    has a mapping. An unmapped subclass therefore keeps its parent's MySQL
    type instead of being rewritten to ``LONGTEXT``, which would silently
    turn e.g. a ``DateField`` or ``SmallIntegerField`` column into text.
    ``LONGTEXT`` remains the fallback when nothing matches.

    Defaults on ``*TEXT`` columns are emitted as parenthesised expressions,
    ``DEFAULT ('...')``. MySQL rejects literal defaults on TEXT/BLOB/JSON
    columns; expression defaults require MySQL 8.0.13+.

    Args:
        table_name: Table owning the column.
        field: Model field describing the desired column.
        field_name: Column to modify.

    Returns:
        str: An ``ALTER TABLE ... MODIFY COLUMN`` statement.
    """
    import json

    def sql_string(value: str) -> str:
        # MySQL string literal: backslash is an escape character (unless the
        # server uses NO_BACKSLASH_ESCAPES), so double it before doubling quotes.
        return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'"

    max_length = getattr(field, 'max_length', None)
    varchar = f'VARCHAR({max_length})' if max_length else 'VARCHAR(255)'
    max_digits = getattr(field, 'max_digits', None)
    decimal_places = getattr(field, 'decimal_places', None)
    decimal = (f'DECIMAL({10 if max_digits is None else max_digits},'
               f'{5 if decimal_places is None else decimal_places})')

    mysql_type_map = {
        'CharField': varchar,
        'DateTimeTzField': varchar,
        'TextField': 'LONGTEXT',
        'LongTextField': 'LONGTEXT',
        'JSONField': 'LONGTEXT',
        'ListField': 'LONGTEXT',
        'SerializedField': 'LONGTEXT',
        'SmallIntegerField': 'SMALLINT',
        'IntegerField': 'INT',
        'BigIntegerField': 'BIGINT',
        'BigAutoField': 'BIGINT',
        'FloatField': 'DOUBLE',
        'DecimalField': decimal,
        'BooleanField': 'TINYINT(1)',
        'DateTimeField': 'DATETIME',
        'DateField': 'DATE',
        'TimeField': 'TIME',
    }
    mysql_type = next(
        (mysql_type_map[klass.__name__] for klass in field.__class__.__mro__ if klass.__name__ in mysql_type_map),
        'LONGTEXT',
    )

    parts = [f'`{field_name}`', mysql_type, 'NULL' if field.null else 'NOT NULL']

    default_val = field.default
    default_sql = None
    if isinstance(default_val, str):
        default_sql = sql_string(default_val)
    elif isinstance(default_val, bool):  # before int: bool is an int subclass
        default_sql = '1' if default_val else '0'
    elif isinstance(default_val, (int, float)):
        default_sql = f'{default_val}'
    elif isinstance(default_val, (dict, list)):
        default_sql = sql_string(json.dumps(default_val))
    if default_sql is not None:
        if mysql_type.endswith('TEXT'):
            default_sql = f'({default_sql})'
        parts.append(f'DEFAULT {default_sql}')

    if getattr(field, 'help_text', None):
        parts.append(f'COMMENT {sql_string(field.help_text)}')

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
def generate_migration_content(new_tables: list, field_changes: dict, migrate_dir: str, migration_name: str, drop_fields: bool = False) -> str:
    """Render the source text of a peewee-migrate migration module.

    ``migrate()`` first creates *new_tables* via ``migrator.create_model``. It
    then adds, modifies and (when *drop_fields*) drops columns of existing
    tables with raw SQL through ``migrator.sql``. ``rollback()`` undoes those
    steps in reverse order.

    SQL statements and table names are embedded with ``repr()``. This keeps
    the generated module valid Python even when the SQL contains double quotes
    (e.g. a JSON default such as ``'{"a": 1}'``) or backslashes; a plain
    ``"..."`` wrapper breaks in that case.

    Args:
        new_tables: Model classes whose tables do not exist yet.
        field_changes: ``{table_name: {'added': ..., 'changed': ..., 'removed': ...}}``
            in the shape returned by ``compare_fields``.
        migrate_dir: Not used by this function.
        migration_name: Not used by this function.
        drop_fields: Emit ``DROP COLUMN`` (and a re-add in rollback) for the
            ``removed`` entries.

    Returns:
        str: The complete module source.
    """
    lines = [
        '"""Peewee migrations."""',
        '',
        'from contextlib import suppress',
        '',
        'import peewee as pw',
        'from peewee_migrate import Migrator',
        '',
        '',
        'with suppress(ImportError):',
        '    import playhouse.postgres_ext as pw_pext',
        '',
        '',
        'def migrate(migrator: Migrator, database: pw.Database, *, fake=False):',
        '    """Write your migrations here."""',
        '',
    ]

    def sql_call(statement: str) -> str:
        # repr() always yields a valid Python string literal for the SQL text.
        return f'    migrator.sql({statement!r})'

    # migrate(): create_model for new tables
    for model in new_tables:
        table_name = model._meta.table_name
        lines.append('    @migrator.create_model')
        lines.append(f'    class {model.__name__}(pw.Model):')
        for field_name, field in model._meta.fields.items():
            lines.append(f'        {field_name} = {generate_field_code(field, field_name)}')
        lines.append('')
        lines.append('        class Meta:')
        lines.append(f'            table_name = {table_name!r}')
        indexes = getattr(model._meta, 'indexes', None)
        if indexes:
            lines.append(f'            indexes = {indexes!r}')
        lines.append('')

    # migrate(): add new columns to existing tables
    for table_name, changes in field_changes.items():
        for field_name, field in (changes.get('added') or {}).items():
            sql, index_sql = generate_add_field_sql(table_name, field, field_name)
            lines.append(sql_call(sql))
            if index_sql:
                lines.append(sql_call(index_sql))
            lines.append('')

    # migrate(): change column types
    for table_name, changes in field_changes.items():
        for field_name, (old_info, field) in (changes.get('changed') or {}).items():
            lines.append(sql_call(generate_modify_field_sql(table_name, field, field_name)))
            lines.append('')

    # migrate(): drop columns removed from the models (only when requested)
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get('removed') or {}).items():
                lines.append(f'    # WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!')
                lines.append(sql_call(generate_drop_field_sql(table_name, field_name)))
                lines.append('')

    lines.append('')
    lines.append('def rollback(migrator: Migrator, database: pw.Database, *, fake=False):')
    lines.append('    """Write your rollback migrations here."""')
    lines.append('')

    # rollback(): re-add dropped columns first, since later steps may rely on them
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get('removed') or {}).items():
                add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name)
                lines.append(f'    # Re-add dropped column `{field_name}` to `{table_name}` (data is lost)')
                lines.append(sql_call(add_sql))
                if index_sql:
                    lines.append(sql_call(index_sql))

    # rollback(): restore original column types before removing added columns
    for table_name, changes in field_changes.items():
        for field_name, (old_info, field) in (changes.get('changed') or {}).items():
            lines.append('    # Note: Data values may need manual handling if type conversion caused data loss')
            lines.append(sql_call(generate_rollback_modify_sql(table_name, old_info, field_name)))

    # rollback(): drop the columns that migrate() added
    for table_name, changes in field_changes.items():
        for field_name in (changes.get('added') or {}):
            lines.append(sql_call(generate_rollback_field_sql(table_name, field_name)))

    # rollback(): remove the created tables in reverse creation order
    for model in reversed(new_tables):
        lines.append(f'    migrator.remove_model({model._meta.table_name!r})')

    lines.append('')
    return '\n'.join(lines)
def create_migration(router: Router, models: list, db, name: str = "auto", drop_fields: bool = False):
    """Create a new migration file by diffing *models* against the live database.

    Detects:
    1. New tables -> ``migrator.create_model``
    2. New fields in existing tables -> ``ALTER TABLE ... ADD COLUMN``
    3. Field base-type changes -> ``ALTER TABLE ... MODIFY COLUMN``
    4. Removed fields (only when ``drop_fields``) -> ``ALTER TABLE ... DROP COLUMN``

    Args:
        router: peewee-migrate Router; its ``migrate_dir`` receives the file.
        models: Model classes to compare against the database.
        db: Database connection (``db.database`` is the schema name).
        name: Migration name; the file is written as ``NNN_<name>.py``.
        drop_fields: Whether to include DROP COLUMN for removed fields.

    Returns:
        str | None: Path of the written file, or ``None`` when no change was found.

    Raises:
        Exception: Any failure is logged and re-raised. ``FileExistsError`` is
        raised instead of overwriting an existing migration file.
    """
    try:
        cursor = db.execute_sql(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
            (db.database,)
        )
        existing_tables = {row[0] for row in cursor.fetchall()}

        new_tables = []
        field_changes = {}
        for model in models:
            table_name = model._meta.table_name
            if table_name not in existing_tables:
                new_tables.append(model)
                logger.info(f"New table detected: {table_name}")
                continue

            logger.info(f"Checking existing table: {table_name}")
            # Model fields minus the primary key, base-model bookkeeping fields
            # and auto-created fields.
            model_fields = {}
            for field_name, field in model._meta.fields.items():
                if field_name in ('id', 'create_time', 'create_date', 'update_time', 'update_date'):
                    continue
                if getattr(field, '_auto_created', False):
                    continue
                model_fields[field_name] = field

            changes = compare_fields(model_fields, get_table_columns(db, table_name))
            if changes['added'] or changes['changed'] or changes['removed']:
                field_changes[table_name] = changes

        has_removed = any(changes.get('removed') for changes in field_changes.values())
        if not drop_fields and has_removed:
            removed_details = [
                f"{table_name}.{col_name}"
                for table_name, changes in field_changes.items()
                for col_name in (changes.get('removed') or {})
            ]
            logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}")
            # Without --drop, removed fields are reported but not acted on.
            for changes in field_changes.values():
                changes['removed'] = {}

        has_additive = any(changes['added'] or changes['changed'] for changes in field_changes.values())
        if not new_tables and not has_additive and not (drop_fields and has_removed):
            logger.info("No schema changes detected, migration not created")
            return None

        migration_content = generate_migration_content(new_tables, field_changes, router.migrate_dir, name, drop_fields=drop_fields)

        # Number after the highest existing NNN_ prefix. Counting files would
        # reuse a number whenever an earlier migration had been deleted, and
        # then overwrite the existing file with that number.
        existing_numbers = []
        for filename in os.listdir(router.migrate_dir):
            if not filename.endswith('.py') or filename.startswith('_'):
                continue
            match = re.match(r'(\d+)_', filename)
            if match:
                existing_numbers.append(int(match.group(1)))
        migration_num = max(existing_numbers, default=0) + 1
        migration_file = os.path.join(router.migrate_dir, f'{migration_num:03d}_{name}.py')

        # 'x' refuses to clobber an existing file. UTF-8 keeps non-ASCII
        # defaults and comments intact regardless of the platform encoding.
        with open(migration_file, 'x', encoding='utf-8') as f:
            f.write(migration_content)

        logger.info(f"Created migration: {migration_file}")
        return migration_file
    except Exception as e:
        logger.error(f"Failed to create migration: {e}")
        raise
def run_migrations(router: Router):
    """Apply every migration that ``router.diff`` reports as pending.

    Logs and does nothing when nothing is pending. Any failure is logged and
    re-raised.
    """
    try:
        if router.diff:
            router.run()
            logger.info("Migrations completed successfully")
        else:
            logger.info("No pending migrations to run")
    except Exception as err:
        logger.error(f"Failed to run migrations: {err}")
        raise
def list_migrations(router: Router):
    """Log every migration in ``router.todo`` tagged ``applied`` or ``pending``.

    A migration counts as applied when it appears in ``router.done``.
    """
    all_migrations = router.todo
    if not all_migrations:
        logger.info("No migration files found")
        return

    logger.info("Available migrations:")
    applied = set(router.done)
    for migration_name in all_migrations:
        state = "applied" if migration_name in applied else "pending"
        logger.info(f" [{state}] {migration_name}")
def diff_schema(models: list, db):
    """Log the differences between *models* and the live database schema.

    The report covers:
      * tables defined by models but missing from the database;
      * database tables with no model (``migratehistory`` is ignored);
      * per-table added / changed / removed fields;
      * a final summary.

    Nothing in the database is modified. Model fields are filtered the same
    way as in ``create_migration``: ``id``, the ``create_*``/``update_*``
    bookkeeping fields and fields flagged ``_auto_created`` are skipped.
    This keeps ``--diff`` previewing the same field set that ``--create``
    acts on.
    """
    logger.info("Checking schema differences...")
    # Tables to ignore (managed by peewee-migrate)
    IGNORE_TABLES = {'migratehistory'}

    model_tables = {model._meta.table_name for model in models}
    logger.info(f"Found {len(model_tables)} model tables")

    cursor = db.execute_sql(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
        (db.database,)
    )
    existing_tables = {row[0] for row in cursor.fetchall() if row[0] not in IGNORE_TABLES}

    missing_tables = model_tables - existing_tables
    if missing_tables:
        logger.warning(f"Tables not in database ({len(missing_tables)}): {', '.join(sorted(missing_tables))}")

    extra_tables = existing_tables - model_tables
    if extra_tables:
        logger.info(f"Tables in database but not in models: {', '.join(sorted(extra_tables))}")

    common_tables = model_tables & existing_tables
    if common_tables:
        logger.info(f"\nChecking field differences for {len(common_tables)} existing tables...")

    total_added = 0
    total_changed = 0
    total_removed = 0
    for model in models:
        table_name = model._meta.table_name
        if table_name not in common_tables:
            continue

        model_fields = {}
        for field_name, field in model._meta.fields.items():
            if field_name in ('id', 'create_time', 'create_date', 'update_time', 'update_date'):
                continue
            if getattr(field, '_auto_created', False):
                continue
            model_fields[field_name] = field

        changes = compare_fields(model_fields, get_table_columns(db, table_name))

        if changes['added']:
            total_added += len(changes['added'])
            field_details = [f"{k}:{v.__class__.__name__}" for k, v in changes['added'].items()]
            logger.info(f" {table_name}: {len(changes['added'])} new field(s) - {field_details}")
        if changes['changed']:
            total_changed += len(changes['changed'])
            field_details = [f"{k}:{v[1].__class__.__name__}" for k, v in changes['changed'].items()]
            logger.info(f" {table_name}: {len(changes['changed'])} changed field(s) - {field_details}")
        if changes['removed']:
            total_removed += len(changes['removed'])
            field_details = [f"{k}:{v['column_type']}" for k, v in changes['removed'].items()]
            logger.warning(f" {table_name}: {len(changes['removed'])} removed field(s) - {field_details}")

    logger.info(f"\nSummary: {total_added} new fields, {total_changed} changed fields, {total_removed} removed fields")
def main():
    """Command-line entry point: parse arguments and run the requested actions.

    Validates ``--version`` (and ``--name`` when ``--create`` is given), makes
    sure ``tools/migrate/<version_dir>`` exists, loads the models, connects to
    MySQL, and runs --list / --create / --migrate / --diff in that order. The
    connection is always closed afterwards.
    """
    parser = argparse.ArgumentParser(
        description='Database Schema Synchronization Tool using peewee-migrate',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# List all migrations
python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
# Create migration from model changes
python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
# Create migration including dropped fields (destructive!)
python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
# Run all pending migrations
python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
# Show schema differences
python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
"""
    )

    # Database connection options
    parser.add_argument('--host', type=str, required=True, help='MySQL host')
    parser.add_argument('--port', type=int, default=3306, help='MySQL port (default: 3306)')
    parser.add_argument('--user', type=str, required=True, help='MySQL user')
    parser.add_argument('--password', type=str, required=True, help='MySQL password')
    parser.add_argument('--database', type=str, required=True, help='MySQL database name')

    # Version option
    parser.add_argument('--version', '-v', type=str, required=True,
                        help='Version number in format vxx.xx.xx (e.g., v0.25.4)')

    # Action options
    parser.add_argument('--list', '-l', action='store_true', help='List all migrations')
    parser.add_argument('--create', '-c', action='store_true',
                        help='Create migration from model changes (auto-detect)')
    parser.add_argument('--migrate', '-m', action='store_true', help='Run pending migrations')
    parser.add_argument('--diff', '-d', action='store_true', help='Show schema differences')

    # Migration options
    parser.add_argument('--name', '-n', type=str, default='auto', help='Migration name')
    parser.add_argument('--drop', action='store_true',
                        help='Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)')

    args = parser.parse_args()

    if not validate_version(args.version):
        logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.25.4)")
        sys.exit(1)

    if not any([args.list, args.create, args.migrate, args.diff]):
        parser.print_help()
        logger.error("Please specify at least one action: --list, --create, --migrate, or --diff")
        sys.exit(1)

    # --name becomes part of the file name "NNN_<name>.py". A path separator
    # would write outside the migration directory. A dot may stop
    # peewee-migrate from recognising the file as a migration
    # (NOTE(review): confirm against the installed peewee-migrate filemask).
    if args.create and not re.fullmatch(r'[A-Za-z0-9_]+', args.name):
        logger.error(f"Invalid migration name: {args.name}. Use only letters, digits and underscores.")
        sys.exit(1)

    version_dir = version_to_dirname(args.version)
    migrate_dir = os.path.join(PROJECT_BASE, 'tools', 'migrate', version_dir)
    logger.info(f"Version: {args.version}")
    logger.info(f"Migration directory: {migrate_dir}")
    os.makedirs(migrate_dir, exist_ok=True)

    logger.info("Loading database models from api/db/db_models.py...")
    models, _ = load_db_models()
    logger.info(f"Found {len(models)} model classes")

    db = create_database_connection(
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=args.database
    )

    try:
        db.connect()
        logger.info(f"Connected to database: {args.database}")

        router = Router(
            db,
            migrate_dir,
            ignore=['basemodel', 'base_model', 'migratehistory']
        )

        if args.list:
            list_migrations(router)
        if args.create:
            create_migration(router, models, db, args.name, drop_fields=args.drop)
        if args.migrate:
            run_migrations(router)
        if args.diff:
            diff_schema(models, db)
    finally:
        if not db.is_closed():
            db.close()
            logger.info("Database connection closed")

    logger.info("Done.")


if __name__ == '__main__':
    main()