mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-19 07:46:59 +08:00
### What problem does this PR solve? - Update version tags in README files (including translations) from v0.25.3 to v0.25.4 - Modify Docker image references and documentation to reflect new version - Update version badges and image descriptions - Maintain consistency across all language variants of README files ### Type of change - [x] Documentation Update
951 lines
35 KiB
Python
951 lines
35 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
Database Schema Sync Script
|
|
|
|
This script synchronizes database models defined in api/db/db_models.py
|
|
with the actual database schema using peewee-migrate.
|
|
|
|
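Features:
1. Reads model definitions from api/db/db_models.py
2. Compares them with the existing tables of the MySQL database given by the
   command-line connection options
3. Generates migration files in tools/migrate/{version}/, with the dots in the
   version replaced by underscores (e.g. v0.25.4 -> tools/migrate/v0_25_4/)
4. Applies pending migrations (--migrate) and reports schema differences (--diff)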
|
|
"""
|
|
|
|
import argparse
|
|
import importlib.util
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
|
|
from peewee import MySQLDatabase, Model, Field
|
|
from peewee_migrate import Router
|
|
|
|
# Add project root to path for imports.
# PROJECT_BASE is the grandparent of the directory that contains this file.
# The script resolves api/db/db_models.py and tools/migrate/<version_dir>
# relative to it, so it is presumably the repository root — confirm if this
# script is ever moved to a different depth.
PROJECT_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Prepended (index 0) so project packages take precedence over same-named
# installed packages; presumably needed for the imports made by db_models.py.
sys.path.insert(0, PROJECT_BASE)

# Configure logging: root logger at INFO with timestamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def validate_version(version: str) -> bool:
    """Return True if *version* has the form ``v<major>.<minor>.<patch>``.

    Each component must be one or more ASCII digits, e.g. ``v0.25.4``.

    The whole string must match. ``re.match`` with a ``$`` anchor also accepts
    a trailing newline (``"v0.25.4\\n"``), which would then leak into the
    migration directory name, so ``re.fullmatch`` is used. ``[0-9]`` is used
    instead of ``\\d`` so that non-ASCII digits are rejected as well.
    """
    return re.fullmatch(r'v[0-9]+\.[0-9]+\.[0-9]+', version) is not None
|
|
|
|
|
|
def version_to_dirname(version: str) -> str:
    """Map a version tag to a directory-safe name by turning every dot into an underscore.

    Example: ``'v0.25.4'`` becomes ``'v0_25_4'``.
    """
    segments = version.split('.')
    return '_'.join(segments)
|
|
|
|
|
|
def load_db_models():
    """Import ``api/db/db_models.py`` from PROJECT_BASE and collect its model classes.

    Returns:
        tuple: ``(models, module)``.
            ``models`` lists every ``peewee.Model`` subclass found among the
            module's attributes, in ``inspect.getmembers`` (name-sorted) order.
            It excludes ``Model`` itself, the ``BaseModel`` / ``DataBaseModel``
            bases, and classes whose ``_meta`` has no ``database`` attribute.
            ``module`` is the loaded module object.

    Raises:
        FileNotFoundError: If db_models.py does not exist under PROJECT_BASE.
    """
    source_file = os.path.join(PROJECT_BASE, 'api', 'db', 'db_models.py')
    if not os.path.exists(source_file):
        raise FileNotFoundError(f"db_models.py not found at {source_file}")

    # Load the file as a standalone module named "db_models".
    module_spec = importlib.util.spec_from_file_location("db_models", source_file)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    excluded_bases = ('BaseModel', 'DataBaseModel')

    def _is_concrete_model(candidate):
        # Keep only real Model subclasses, never Model itself.
        if not inspect.isclass(candidate) or candidate is Model:
            return False
        if not issubclass(candidate, Model):
            return False
        # Skip the shared base classes.
        if candidate.__name__ in excluded_bases:
            return False
        return hasattr(candidate._meta, 'database')

    collected = [member for _, member in inspect.getmembers(module) if _is_concrete_model(member)]
    return collected, module
|
|
|
|
|
|
def create_database_connection(host: str, port: int, user: str, password: str, database: str):
    """Construct a peewee ``MySQLDatabase`` from the command-line connection options.

    The connection is configured with the ``utf8mb4`` character set.
    """
    connect_kwargs = dict(
        host=host,
        port=port,
        user=user,
        password=password,
        charset='utf8mb4',
    )
    return MySQLDatabase(database, **connect_kwargs)
|
|
|
|
|
|
# MySQL type to Peewee field type mapping.
#
# Keys are lower-cased information_schema ``data_type`` values. Values are base
# peewee field names of the kind get_base_field_type() returns, so the two can
# be compared directly; unknown data types fall back to 'TextField' at the
# lookup site in get_table_columns().
#
# 'date' and 'time' map to DateField / TimeField. Without them those columns
# fell back to 'TextField', so every DateField / TimeField model field was
# reported as a type change.
#
# The 'tinyint(1)' entry documents the boolean convention only: get_table_columns
# detects booleans by matching ``column_type`` (data_type alone is 'tinyint').
# NOTE(review): 'decimal' maps to FloatField, so a model declaring DecimalField
# will be reported as changed against a DECIMAL column — confirm intent.
MYSQL_TO_PEEWEE_TYPE = {
    'varchar': 'CharField',
    'char': 'CharField',
    'text': 'TextField',
    'longtext': 'TextField',
    'mediumtext': 'TextField',
    'int': 'IntegerField',
    'integer': 'IntegerField',
    'bigint': 'BigIntegerField',
    'float': 'FloatField',
    'double': 'FloatField',
    'decimal': 'FloatField',
    'datetime': 'DateTimeField',
    'timestamp': 'DateTimeField',
    'date': 'DateField',
    'time': 'TimeField',
    'tinyint(1)': 'BooleanField',
    'tinyint': 'IntegerField',
    'smallint': 'IntegerField',
    'mediumint': 'IntegerField',
}
|
|
|
|
# Reverse lookup: base peewee field name -> lower-case MySQL data type keyword.
PEEWEE_TO_MYSQL_TYPE = dict(
    CharField='varchar',
    TextField='text',
    IntegerField='int',
    BigIntegerField='bigint',
    FloatField='float',
    BooleanField='tinyint',
    DateTimeField='datetime',
)
|
|
|
|
|
|
def get_table_columns(db, table_name: str) -> dict:
    """Read column metadata for *table_name* from ``information_schema.columns``.

    Args:
        db: Database connection; ``db.database`` names the schema to search.
        table_name: Table to inspect.

    Returns:
        dict: ``{column_name: info}`` in ordinal (definition) order. ``info`` has
        the keys:
            - ``data_type``, ``column_type``: lower-cased.
            - ``peewee_type``: base peewee field name.
            - ``nullable``: bool.
            - ``default``: raw ``column_default``.
            - ``is_primary``: bool.
            - ``column_key``: raw key code (``PRI``/``UNI``/``MUL``/``''``).
            - ``extra``: raw extra text.
    """
    cursor = db.execute_sql("""
        SELECT
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """, (db.database, table_name))
    try:
        rows = cursor.fetchall()
    finally:
        cursor.close()

    columns = {}
    for col_name, data_type, column_type, is_nullable, column_default, column_key, extra in rows:
        data_type = data_type.lower()
        column_type = column_type.lower()

        # Booleans are stored as tinyint(1); data_type alone is just 'tinyint',
        # so the full column type has to be checked.
        if column_type.startswith('tinyint(1)'):
            peewee_type = 'BooleanField'
        else:
            peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, 'TextField')

        columns[col_name] = {
            'data_type': data_type,
            'column_type': column_type,
            'peewee_type': peewee_type,
            'nullable': is_nullable == 'YES',
            'default': column_default,
            'is_primary': column_key == 'PRI',
            # Raw key code, needed by rollback code that re-creates indexes
            # for MUL (non-unique index) columns.
            'column_key': column_key or '',
            'extra': extra or '',
        }

    return columns
|
|
|
|
|
|
def get_peewee_field_type(field: Field) -> str:
    """Return the concrete class name of *field*, e.g. ``'CharField'`` or ``'JSONField'``."""
    return type(field).__name__
|
|
|
|
|
|
def get_base_field_type(field: Field) -> str:
    """Return the first standard peewee field name found in *field*'s class hierarchy.

    Custom field types (such as DateTimeTzField or JSONField) subclass stock
    peewee types. Walking the MRO maps them back to the stock name so they can
    be compared with the types inferred from MySQL. Returns ``'TextField'`` when
    no standard name appears anywhere in the hierarchy.
    """
    standard_names = frozenset((
        'CharField', 'TextField', 'IntegerField', 'BigIntegerField',
        'FloatField', 'BooleanField', 'DateTimeField', 'DateField',
        'TimeField', 'DecimalField', 'ForeignKeyField', 'ManyToManyField',
        'PrimaryKeyField', 'AutoField',
    ))
    hierarchy_names = (klass.__name__ for klass in field.__class__.__mro__)
    return next((name for name in hierarchy_names if name in standard_names), 'TextField')
|
|
|
|
|
|
def normalize_field_type(field: Field) -> str:
    """Return the comparison key for *field*: its standard base peewee type name."""
    base_type = get_base_field_type(field)
    return base_type
|
|
|
|
|
|
def compare_fields(model_fields: dict, db_columns: dict, skip_fields=None) -> dict:
    """Compare model fields with database columns.

    Args:
        model_fields: ``{field_name: peewee Field}`` to check.
        db_columns: ``{column_name: column_info}`` as built by get_table_columns().
        skip_fields: Names ignored on both sides of the comparison. Defaults to
            the primary key and the base timestamp columns (``id``,
            ``create_time``, ``create_date``, ``update_time``, ``update_date``).
            The callers in this script leave those fields out of
            *model_fields*, so they must also be ignored on the database side.
            Otherwise every table reports them as removed, and ``--drop``
            would emit ``DROP COLUMN`` for them.

    Returns:
        dict: {
            'added': {field_name: field_obj},  # in the model, not in the DB
            'changed': {field_name: (old_info, new_field)},  # base type differs
            'removed': {field_name: col_info},  # in the DB, not in the model
        }
    """
    if skip_fields is None:
        skip_fields = {'id', 'create_time', 'create_date', 'update_time', 'update_date'}

    result = {
        'added': {},
        'changed': {},
        'removed': {},
    }

    for field_name, field in model_fields.items():
        if field_name in skip_fields:
            continue

        # Field declared by the model but missing from the table.
        if field_name not in db_columns:
            result['added'][field_name] = field
            logger.info(f" New field detected: {field_name} ({field.__class__.__name__})")
            continue

        # Compare base types: custom field classes are reduced to their stock
        # peewee ancestor before comparing with the type inferred from MySQL.
        db_col = db_columns[field_name]
        model_base_type = normalize_field_type(field)
        db_type = db_col['peewee_type']
        if model_base_type != db_type:
            result['changed'][field_name] = (db_col, field)
            logger.info(f" Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})")

    # Detect removed fields: columns in DB but not in model
    for col_name, col_info in db_columns.items():
        if col_name in skip_fields or col_name in model_fields:
            continue
        result['removed'][col_name] = col_info
        logger.info(f" Removed field detected: {col_name} ({col_info['column_type']})")

    return result
|
|
|
|
|
|
def generate_field_code(field: Field, field_name: str) -> str:
    """Render *field* as a ``pw.<Type>(...)`` expression for a generated migration.

    Project-specific field classes listed in ``custom_to_standard`` are emitted
    as the stock peewee type that stores the same column. Any other class is
    emitted under its own name, so it must exist in the ``peewee`` module.

    Args:
        field: The model field to render.
        field_name: Attribute name of the field. Not used in the output; kept
            for API compatibility.

    Returns:
        str: Source code such as ``pw.CharField(max_length=32, primary_key=True)``.
    """
    # Map custom field types to standard peewee types for migration.
    # These custom types will be stored as their underlying standard type.
    custom_to_standard = {
        'LongTextField': 'TextField',
        'JSONField': 'TextField',
        'ListField': 'TextField',
        'SerializedField': 'TextField',
        'DateTimeTzField': 'CharField',
    }
    field_class = field.__class__.__name__
    pw_field_class = custom_to_standard.get(field_class, field_class)

    args = []

    # max_length for CharField
    if pw_field_class == 'CharField' and getattr(field, 'max_length', None) is not None:
        args.append(f"max_length={field.max_length}")

    # Without primary_key=True the generated model would not use this field as
    # its key, and the created table would not match the model definition.
    if getattr(field, 'primary_key', False):
        args.append("primary_key=True")

    if field.null:
        args.append("null=True")

    # repr() always yields a valid Python literal for these types, including
    # strings containing quotes, backslashes or newlines. Other defaults
    # (e.g. callables) are not emitted.
    if isinstance(field.default, (str, bool, int, float, dict, list)):
        args.append(f"default={field.default!r}")

    if getattr(field, 'index', False):
        args.append("index=True")

    if getattr(field, 'unique', False):
        args.append("unique=True")

    return f"pw.{pw_field_class}({', '.join(args)})"
|
|
|
|
|
|
def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> tuple:
    """Generate raw SQL for adding a field to an existing MySQL table.

    This is used for existing tables where ``migrator.add_fields`` doesn't work
    because the model is not registered in ``migrator.orm``.

    Args:
        table_name: Table to alter.
        field: Model field describing the new column. Its type, ``null``,
            ``default``, ``help_text`` (emitted as COMMENT) and ``index`` are used.
        field_name: Name of the column to add.

    Returns:
        tuple: ``(add_column_sql, index_sql)``. ``index_sql`` is ``None`` unless
        the field has ``index=True``.
    """
    import json

    max_length = getattr(field, 'max_length', None)
    varchar = f'VARCHAR({max_length})' if max_length else 'VARCHAR(255)'
    mysql_type_map = {
        'CharField': varchar,
        'TextField': 'LONGTEXT',
        'LongTextField': 'LONGTEXT',
        'JSONField': 'LONGTEXT',
        'ListField': 'LONGTEXT',
        'SerializedField': 'LONGTEXT',
        'IntegerField': 'INT',
        'BigIntegerField': 'BIGINT',
        'FloatField': 'DOUBLE',
        'BooleanField': 'TINYINT(1)',
        'DateTimeField': 'DATETIME',
        'DateField': 'DATE',
        'TimeField': 'TIME',
        'DateTimeTzField': varchar,
    }
    # Walk the MRO so that a subclass of a mapped type (e.g. a custom CharField
    # subclass) gets its parent's column type instead of the LONGTEXT fallback.
    # compare_fields classifies fields by an MRO walk too. With an exact-name
    # lookup, such a column would be created as LONGTEXT and then keep showing
    # up as a type change.
    mysql_type = next(
        (mysql_type_map[klass.__name__] for klass in field.__class__.__mro__
         if klass.__name__ in mysql_type_map),
        'LONGTEXT',
    )

    def quote(text):
        # MySQL string literal. In the default sql_mode, backslash is an escape
        # character, so it is doubled along with single quotes.
        return "'" + text.replace('\\', '\\\\').replace("'", "''") + "'"

    parts = [f'`{field_name}`', mysql_type, 'NULL' if field.null else 'NOT NULL']

    default_val = field.default
    literal = None
    if isinstance(default_val, str):
        literal = quote(default_val)
    elif isinstance(default_val, bool):  # checked before int: bool is an int subclass
        literal = '1' if default_val else '0'
    elif isinstance(default_val, (int, float)):
        literal = f'{default_val}'
    elif isinstance(default_val, (dict, list)):
        literal = quote(json.dumps(default_val))
    if literal is not None:
        if mysql_type == 'LONGTEXT':
            # MySQL rejects literal defaults on TEXT/BLOB/JSON columns. They are
            # accepted only as a parenthesised expression (MySQL 8.0.13+).
            literal = f'({literal})'
        parts.append(f'DEFAULT {literal}')

    if getattr(field, 'help_text', None):
        parts.append(f'COMMENT {quote(field.help_text)}')

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    index_sql = None
    if getattr(field, 'index', False):
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"

    return sql, index_sql
|
|
|
|
|
|
def generate_drop_field_sql(table_name: str, field_name: str) -> str:
    """Return the ALTER TABLE statement that drops column *field_name* from *table_name*."""
    return "ALTER TABLE `%s` DROP COLUMN `%s`" % (table_name, field_name)
|
|
|
|
|
|
def generate_rollback_field_sql(table_name: str, field_name: str) -> str:
    """Return SQL that undoes an ADD COLUMN by dropping *field_name* from *table_name*."""
    table_ref = f"`{table_name}`"
    column_ref = f"`{field_name}`"
    return f"ALTER TABLE {table_ref} DROP COLUMN {column_ref}"
|
|
|
|
|
|
def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: str) -> tuple:
    """Generate SQL for rolling back a dropped field (re-adding it).

    This reconstructs the ADD COLUMN statement from the column info that was
    captured before the field was dropped. The dropped data is not restored.

    Args:
        table_name: Table the column was dropped from.
        col_info: Captured column metadata. Reads ``column_type`` (default
            ``LONGTEXT``), ``nullable`` (default True), ``default`` and, if the
            caller stored it, ``column_key``.
        field_name: Column name to re-add.

    Returns:
        tuple: ``(add_sql, index_sql)``. ``index_sql`` re-creates a plain index
        when ``column_key`` is ``'MUL'``; otherwise it is ``None``.
    """
    parts = [f'`{field_name}`', col_info.get('column_type', 'LONGTEXT')]
    parts.append('NULL' if col_info.get('nullable', True) else 'NOT NULL')

    default_val = col_info.get('default')
    if isinstance(default_val, str):
        if re.fullmatch(r'current_timestamp(\(\d*\))?', default_val, re.IGNORECASE):
            # information_schema reports function defaults as bare text. Quoting
            # CURRENT_TIMESTAMP would make it an (invalid) string default.
            parts.append(f'DEFAULT {default_val}')
        else:
            # Double backslashes (an escape character in MySQL's default
            # sql_mode) and single quotes.
            escaped = default_val.replace('\\', '\\\\').replace("'", "''")
            parts.append(f"DEFAULT '{escaped}'")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    # Re-add index if the column carried a non-unique index
    index_sql = None
    if col_info.get('column_key') == 'MUL':
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"

    return sql, index_sql
|
|
|
|
|
|
def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str:
    """Generate SQL for rolling back a field type change.

    Note: This restores the column type, but data values may need manual handling
    if the type conversion caused data loss or transformation.

    Args:
        table_name: Table containing the column.
        old_info: Column metadata captured before the change. Reads
            ``column_type`` (default ``LONGTEXT``), ``nullable`` (default True)
            and ``default``.
        field_name: Column to restore.

    Returns:
        str: An ``ALTER TABLE ... MODIFY COLUMN`` statement.
    """
    parts = [f'`{field_name}`', old_info.get('column_type', 'LONGTEXT')]
    parts.append('NULL' if old_info.get('nullable', True) else 'NOT NULL')

    default_val = old_info.get('default')
    if isinstance(default_val, str):
        if re.fullmatch(r'current_timestamp(\(\d*\))?', default_val, re.IGNORECASE):
            # information_schema reports function defaults as bare text. Quoting
            # CURRENT_TIMESTAMP would make it an (invalid) string default.
            parts.append(f'DEFAULT {default_val}')
        else:
            # Double backslashes (an escape character in MySQL's default
            # sql_mode) and single quotes.
            escaped = default_val.replace('\\', '\\\\').replace("'", "''")
            parts.append(f"DEFAULT '{escaped}'")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
|
|
|
|
|
|
def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> str:
    """Generate SQL for modifying a field in a MySQL table to match *field*.

    Args:
        table_name: Table containing the column.
        field: Model field describing the desired column. Its type, ``null``,
            ``default`` and ``help_text`` (emitted as COMMENT) are used.
        field_name: Column to modify.

    Returns:
        str: An ``ALTER TABLE ... MODIFY COLUMN`` statement.
    """
    import json

    max_length = getattr(field, 'max_length', None)
    varchar = f'VARCHAR({max_length})' if max_length else 'VARCHAR(255)'
    mysql_type_map = {
        'CharField': varchar,
        'TextField': 'LONGTEXT',
        'LongTextField': 'LONGTEXT',
        'JSONField': 'LONGTEXT',
        'ListField': 'LONGTEXT',
        'SerializedField': 'LONGTEXT',
        'IntegerField': 'INT',
        'BigIntegerField': 'BIGINT',
        'FloatField': 'DOUBLE',
        'BooleanField': 'TINYINT(1)',
        'DateTimeField': 'DATETIME',
        'DateField': 'DATE',
        'TimeField': 'TIME',
        'DateTimeTzField': varchar,
    }
    # Walk the MRO so that a subclass of a mapped type gets its parent's column
    # type, matching the MRO-based classification in compare_fields. With an
    # exact-name lookup, the modified column could still look changed on the
    # next run.
    mysql_type = next(
        (mysql_type_map[klass.__name__] for klass in field.__class__.__mro__
         if klass.__name__ in mysql_type_map),
        'LONGTEXT',
    )

    def quote(text):
        # MySQL string literal. In the default sql_mode, backslash is an escape
        # character, so it is doubled along with single quotes.
        return "'" + text.replace('\\', '\\\\').replace("'", "''") + "'"

    parts = [f'`{field_name}`', mysql_type, 'NULL' if field.null else 'NOT NULL']

    default_val = field.default
    literal = None
    if isinstance(default_val, str):
        literal = quote(default_val)
    elif isinstance(default_val, bool):  # checked before int: bool is an int subclass
        literal = '1' if default_val else '0'
    elif isinstance(default_val, (int, float)):
        literal = f'{default_val}'
    elif isinstance(default_val, (dict, list)):
        literal = quote(json.dumps(default_val))
    if literal is not None:
        if mysql_type == 'LONGTEXT':
            # MySQL rejects literal defaults on TEXT/BLOB/JSON columns. They are
            # accepted only as a parenthesised expression (MySQL 8.0.13+).
            literal = f'({literal})'
        parts.append(f'DEFAULT {literal}')

    if getattr(field, 'help_text', None):
        parts.append(f'COMMENT {quote(field.help_text)}')

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
|
|
|
|
|
|
def generate_migration_content(new_tables: list, field_changes: dict, migrate_dir: str, migration_name: str, drop_fields: bool = False) -> str:
    """Render the source of a peewee-migrate migration module.

    ``migrate()`` does the following:
        - Creates *new_tables* with ``@migrator.create_model``.
        - Applies column changes on existing tables as raw SQL through
          ``migrator.sql``: added columns (and their indexes), type changes,
          and, only when *drop_fields* is true, dropped columns.

    ``rollback()`` undoes these in this order:
        1. Re-adds dropped columns (their data is lost).
        2. Restores the old column types.
        3. Drops the added columns.
        4. Removes the new tables, in reverse order.

    Args:
        new_tables: Model classes whose tables do not exist yet.
        field_changes: ``{table_name: {'added': ..., 'changed': ..., 'removed': ...}}``.
        migrate_dir: Unused; kept for API compatibility.
        migration_name: Unused; kept for API compatibility.
        drop_fields: Emit DROP COLUMN (and its rollback) for removed fields.

    Returns:
        str: The migration module source, ending with a newline.
    """

    def sql_line(statement):
        # repr() always produces a valid Python string literal, even when the SQL
        # contains double quotes, backslashes or newlines (e.g. from string
        # defaults or column comments). Wrapping the SQL in "..." by hand
        # produced an unparsable migration file in those cases.
        return f'    migrator.sql({statement!r})'

    lines = [
        '"""Peewee migrations."""',
        '',
        'from contextlib import suppress',
        '',
        'import peewee as pw',
        'from peewee_migrate import Migrator',
        '',
        '',
        'with suppress(ImportError):',
        '    import playhouse.postgres_ext as pw_pext',
        '',
        '',
        'def migrate(migrator: Migrator, database: pw.Database, *, fake=False):',
        '    """Write your migrations here."""',
        '',
    ]

    # Generate create_model for new tables
    for model in new_tables:
        table_name = model._meta.table_name
        lines.append('    @migrator.create_model')
        lines.append(f'    class {model.__name__}(pw.Model):')
        for field_name, field in model._meta.fields.items():
            field_code = generate_field_code(field, field_name)
            lines.append(f'        {field_name} = {field_code}')
        lines.append('')
        lines.append('        class Meta:')
        lines.append(f'            table_name = {table_name!r}')
        indexes = getattr(model._meta, 'indexes', None)
        if indexes:
            lines.append(f'            indexes = {indexes!r}')
        lines.append('')

    # Added columns on existing tables (raw SQL: these models are not in migrator.orm)
    for table_name, changes in field_changes.items():
        for field_name, field in (changes.get('added') or {}).items():
            sql, index_sql = generate_add_field_sql(table_name, field, field_name)
            lines.append(sql_line(sql))
            if index_sql:
                lines.append(sql_line(index_sql))
            lines.append('')

    # Column type changes on existing tables
    for table_name, changes in field_changes.items():
        for field_name, (old_info, field) in (changes.get('changed') or {}).items():
            lines.append(sql_line(generate_modify_field_sql(table_name, field, field_name)))
            lines.append('')

    # Dropped columns (only when explicitly requested)
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name in (changes.get('removed') or {}):
                lines.append(f'    # WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!')
                lines.append(sql_line(generate_drop_field_sql(table_name, field_name)))
                lines.append('')

    # Generate rollback
    lines.append('')
    lines.append('def rollback(migrator: Migrator, database: pw.Database, *, fake=False):')
    lines.append('    """Write your rollback migrations here."""')
    lines.append('')

    # Rollback: re-add dropped fields first, since later steps may depend on them
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get('removed') or {}).items():
                add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name)
                lines.append(f'    # Re-add dropped column `{field_name}` to `{table_name}` (data is lost)')
                lines.append(sql_line(add_sql))
                if index_sql:
                    lines.append(sql_line(index_sql))

    # Rollback: reverse field type changes before removing added fields
    for table_name, changes in field_changes.items():
        for field_name, (old_info, field) in (changes.get('changed') or {}).items():
            lines.append('    # Note: Data values may need manual handling if type conversion caused data loss')
            lines.append(sql_line(generate_rollback_modify_sql(table_name, old_info, field_name)))

    # Rollback: remove added fields using SQL
    for table_name, changes in field_changes.items():
        for field_name in (changes.get('added') or {}):
            lines.append(sql_line(generate_rollback_field_sql(table_name, field_name)))

    # Rollback: remove tables (in reverse order)
    for model in reversed(new_tables):
        lines.append(f'    migrator.remove_model({model._meta.table_name!r})')

    lines.append('')

    return '\n'.join(lines)
|
|
|
|
|
|
def create_migration(router: Router, models: list, db, name: str = "auto", drop_fields: bool = False):
    """Create a new migration file by auto-detecting model changes.

    Detects:
    1. New tables -> ``@migrator.create_model`` class
    2. New fields in existing tables -> raw ``ALTER TABLE ... ADD COLUMN`` SQL
    3. Field base-type changes -> raw ``ALTER TABLE ... MODIFY COLUMN`` SQL
    4. Removed fields (only when *drop_fields* is true) -> raw ``DROP COLUMN`` SQL

    ``id`` and the base timestamp fields are not compared. Columns the model
    still declares are never reported as removed.

    Args:
        router: peewee-migrate Router instance. The file is written to
            ``router.migrate_dir``.
        models: List of model classes to compare against the database.
        db: Database connection.
        name: Migration name, used as the file-name suffix.
        drop_fields: Whether to include DROP COLUMN for removed fields.

    Returns:
        str | None: Path of the created migration file, or ``None`` when no
        schema change was detected.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    # Fields intentionally left out of the comparison.
    base_fields = ('id', 'create_time', 'create_date', 'update_time', 'update_date')

    try:
        # Get existing tables from database
        cursor = db.execute_sql(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
            (db.database,)
        )
        existing_tables = {row[0] for row in cursor.fetchall()}

        new_tables = []
        field_changes = {}

        for model in models:
            table_name = model._meta.table_name

            if table_name not in existing_tables:
                new_tables.append(model)
                logger.info(f"New table detected: {table_name}")
                continue

            logger.info(f"Checking existing table: {table_name}")

            model_fields = {}
            # Every field/column name the model declares, including the skipped
            # ones, so none of them can be mistaken for a removed column.
            known_columns = set()
            for field_name, field in model._meta.fields.items():
                known_columns.add(field_name)
                known_columns.add(getattr(field, 'column_name', None) or field_name)
                if field_name in base_fields:
                    continue
                if getattr(field, '_auto_created', False):
                    continue
                model_fields[field_name] = field

            db_columns = get_table_columns(db, table_name)
            changes = compare_fields(model_fields, db_columns)
            # compare_fields only sees model_fields, so it may report the
            # skipped base columns as "removed". With --drop that would
            # generate DROP COLUMN for them. Keep only columns the model lacks.
            changes['removed'] = {
                col: info for col, info in changes['removed'].items()
                if col not in known_columns
            }

            if changes['added'] or changes['changed'] or changes['removed']:
                field_changes[table_name] = changes

        # Removed fields are only acted on when --drop is given
        has_removed = any(changes.get('removed') for changes in field_changes.values())
        if not drop_fields and has_removed:
            removed_details = [
                f"{table_name}.{col_name}"
                for table_name, changes in field_changes.items()
                for col_name in changes.get('removed') or {}
            ]
            logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}")
            for table_name in field_changes:
                field_changes[table_name]['removed'] = {}

        if not new_tables and not any(changes['added'] or changes['changed'] for changes in field_changes.values()):
            if not (drop_fields and has_removed):
                logger.info("No schema changes detected, migration not created")
                return None

        migration_content = generate_migration_content(new_tables, field_changes, router.migrate_dir, name, drop_fields=drop_fields)

        # Number the file after the highest existing NNN_ prefix. Counting the
        # files would reuse (and overwrite) a number once a migration is deleted.
        highest = 0
        for entry in os.listdir(router.migrate_dir):
            if not entry.endswith('.py') or entry.startswith('_'):
                continue
            match = re.match(r'(\d+)_', entry)
            if match:
                highest = max(highest, int(match.group(1)))
        migration_file = os.path.join(router.migrate_dir, f'{highest + 1:03d}_{name}.py')

        # Explicit UTF-8: the generated source may contain non-ASCII text
        # (e.g. help_text comments) that the platform default cannot encode.
        with open(migration_file, 'w', encoding='utf-8') as f:
            f.write(migration_content)

        logger.info(f"Created migration: {migration_file}")
        return migration_file

    except Exception as e:
        logger.error(f"Failed to create migration: {e}")
        raise
|
|
|
|
|
|
def run_migrations(router: Router):
    """Apply every migration the router reports as pending; log and re-raise on failure."""
    try:
        pending = router.diff
        if pending:
            router.run()
            logger.info("Migrations completed successfully")
        else:
            logger.info("No pending migrations to run")
    except Exception as e:
        logger.error(f"Failed to run migrations: {e}")
        raise
|
|
|
|
|
|
def list_migrations(router: Router):
    """Log each migration file known to the router, tagged as applied or pending."""
    known = router.todo
    if not known:
        logger.info("No migration files found")
        return

    logger.info("Available migrations:")
    applied = set(router.done)
    for migration in known:
        logger.info(f" [{'applied' if migration in applied else 'pending'}] {migration}")
|
|
|
|
|
|
def diff_schema(models: list, db):
    """Log schema differences between the model classes and the live database.

    The report covers:
        - tables declared by a model but missing from the database;
        - tables not backed by any model (peewee-migrate's ``migratehistory``
          is ignored);
        - for tables present on both sides, added / changed / removed columns;
        - a final summary line.

    Nothing is modified. Field filtering matches create_migration: ``id``, the
    base timestamp fields and ``_auto_created`` fields are not compared, and
    columns the model still declares are never reported as removed.
    """
    logger.info("Checking schema differences...")

    # Tables to ignore (managed by peewee-migrate)
    ignore_tables = {'migratehistory'}
    # Fields left out of the comparison (primary key and base timestamps).
    base_fields = ('id', 'create_time', 'create_date', 'update_time', 'update_date')

    model_tables = {model._meta.table_name for model in models}
    logger.info(f"Found {len(model_tables)} model tables")

    cursor = db.execute_sql(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
        (db.database,)
    )
    existing_tables = {row[0] for row in cursor.fetchall() if row[0] not in ignore_tables}

    missing_tables = model_tables - existing_tables
    if missing_tables:
        logger.warning(f"Tables not in database ({len(missing_tables)}): {', '.join(sorted(missing_tables))}")

    extra_tables = existing_tables - model_tables
    if extra_tables:
        logger.info(f"Tables in database but not in models: {', '.join(sorted(extra_tables))}")

    common_tables = model_tables & existing_tables
    if common_tables:
        logger.info(f"\nChecking field differences for {len(common_tables)} existing tables...")

    total_added = 0
    total_changed = 0
    total_removed = 0

    for model in models:
        table_name = model._meta.table_name
        if table_name not in common_tables:
            continue

        model_fields = {}
        # Every field/column name the model declares, including skipped ones,
        # so none of them is reported as removed.
        known_columns = set()
        for field_name, field in model._meta.fields.items():
            known_columns.add(field_name)
            known_columns.add(getattr(field, 'column_name', None) or field_name)
            if field_name in base_fields:
                continue
            if getattr(field, '_auto_created', False):
                continue
            model_fields[field_name] = field

        db_columns = get_table_columns(db, table_name)
        changes = compare_fields(model_fields, db_columns)
        # compare_fields only sees model_fields and may report the skipped
        # base columns as removed; keep only columns the model truly lacks.
        changes['removed'] = {
            col: info for col, info in changes['removed'].items()
            if col not in known_columns
        }

        if changes['added']:
            total_added += len(changes['added'])
            field_details = [f"{k}:{v.__class__.__name__}" for k, v in changes['added'].items()]
            logger.info(f" {table_name}: {len(changes['added'])} new field(s) - {field_details}")

        if changes['changed']:
            total_changed += len(changes['changed'])
            field_details = [f"{k}:{v[1].__class__.__name__}" for k, v in changes['changed'].items()]
            logger.info(f" {table_name}: {len(changes['changed'])} changed field(s) - {field_details}")

        if changes['removed']:
            total_removed += len(changes['removed'])
            field_details = [f"{k}:{v['column_type']}" for k, v in changes['removed'].items()]
            logger.warning(f" {table_name}: {len(changes['removed'])} removed field(s) - {field_details}")

    logger.info(f"\nSummary: {total_added} new fields, {total_changed} changed fields, {total_removed} removed fields")
|
|
|
|
|
|
def _build_arg_parser():
    """Return the argparse parser that describes this script's command line."""
    parser = argparse.ArgumentParser(
        description='Database Schema Synchronization Tool using peewee-migrate',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# List all migrations
python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4

# Create migration from model changes
python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4

# Create migration including dropped fields (destructive!)
python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4

# Run all pending migrations
python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4

# Show schema differences
python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.4
"""
    )

    # (flags, add_argument keyword options), in the order shown by --help:
    # connection options, version, actions, then migration options.
    option_specs = (
        (('--host',), dict(type=str, required=True, help='MySQL host')),
        (('--port',), dict(type=int, default=3306, help='MySQL port (default: 3306)')),
        (('--user',), dict(type=str, required=True, help='MySQL user')),
        (('--password',), dict(type=str, required=True, help='MySQL password')),
        (('--database',), dict(type=str, required=True, help='MySQL database name')),
        (('--version', '-v'), dict(type=str, required=True,
                                   help='Version number in format vxx.xx.xx (e.g., v0.25.4)')),
        (('--list', '-l'), dict(action='store_true', help='List all migrations')),
        (('--create', '-c'), dict(action='store_true',
                                  help='Create migration from model changes (auto-detect)')),
        (('--migrate', '-m'), dict(action='store_true', help='Run pending migrations')),
        (('--diff', '-d'), dict(action='store_true', help='Show schema differences')),
        (('--name', '-n'), dict(type=str, default='auto', help='Migration name')),
        (('--drop',), dict(action='store_true',
                           help='Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)')),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


def main():
    """CLI entry point: validate arguments, connect to MySQL, run the requested actions.

    Actions run in a fixed order: --list, --create, --migrate, --diff. The
    connection is closed on exit. An invalid version or a missing action
    terminates the process with exit code 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if not validate_version(args.version):
        logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.25.4)")
        sys.exit(1)

    requested_actions = (args.list, args.create, args.migrate, args.diff)
    if not any(requested_actions):
        parser.print_help()
        logger.error("Please specify at least one action: --list, --create, --migrate, or --diff")
        sys.exit(1)

    migrate_dir = os.path.join(PROJECT_BASE, 'tools', 'migrate', version_to_dirname(args.version))
    logger.info(f"Version: {args.version}")
    logger.info(f"Migration directory: {migrate_dir}")
    os.makedirs(migrate_dir, exist_ok=True)

    logger.info("Loading database models from api/db/db_models.py...")
    models, _ = load_db_models()
    logger.info(f"Found {len(models)} model classes")

    db = create_database_connection(
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=args.database,
    )

    try:
        db.connect()
        logger.info(f"Connected to database: {args.database}")

        router = Router(db, migrate_dir, ignore=['basemodel', 'base_model', 'migratehistory'])

        dispatch = (
            (args.list, lambda: list_migrations(router)),
            (args.create, lambda: create_migration(router, models, db, args.name, drop_fields=args.drop)),
            (args.migrate, lambda: run_migrations(router)),
            (args.diff, lambda: diff_schema(models, db)),
        )
        for wanted, action in dispatch:
            if wanted:
                action()
    finally:
        if not db.is_closed():
            db.close()
            logger.info("Database connection closed")

    logger.info("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()