diff --git a/api/db/db_models.py b/api/db/db_models.py index ca72be210..8ba327838 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -27,7 +27,23 @@ from functools import wraps from quart_auth import AuthUser from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer -from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField +from peewee import ( + fn, + InterfaceError, + OperationalError, + ProgrammingError, + BigIntegerField, + BooleanField, + CharField, + CompositeKey, + DateTimeField, + Field, + FloatField, + IntegerField, + Metadata, + Model, + TextField, +) from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase @@ -692,7 +708,7 @@ class User(DataBaseModel, AuthUser): access_token = CharField(max_length=255, null=True, index=True) nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True) password = CharField(max_length=255, null=True, help_text="password", index=True) - email = CharField(max_length=255, null=False, help_text="email", index=True) + email = CharField(max_length=255, null=False, help_text="email", unique=True) avatar = TextField(null=True, help_text="avatar base64 string") language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", index=True) color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright", index=True) @@ -1332,6 +1348,42 @@ def alter_db_rename_column(migrator, table_name, old_column_name, new_column_nam # logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}") pass +def migrate_add_unique_email(migrator): + """Deduplicates user emails and add UNIQUE constraint to email column (idempotent)""" + # step 1: rename duplicate rows so the UNIQUE constraint can be applied + try: + duplicates = User.select(User.email).group_by(User.email).having(fn.COUNT(User.id) > 1).tuples() + for (dup_email,) in duplicates: + # Keep the superuser row, or the oldest row if there is no superuser + rows = list( + User + .select(User.id) + .where(User.email == dup_email) + .order_by(User.is_superuser.desc(), User.create_time.asc()) + .tuples() + ) + for (uid,) in rows[1:]: + new_email = f"{dup_email}_DUPLICATE_{uid[:8]}" + User.update(email=new_email).where(User.id == uid).execute() + logging.warning("Renamed duplicate user %s email to %s during migration", uid, new_email) + except Exception as ex: + logging.critical("Failed to deduplicate user.email before adding UNIQUE constraint: %s", ex) + return + + # step 2: add UNIQUE index via migrator + try: + migrate(migrator.add_index("user", ("email",), unique=True)) + except (OperationalError, ProgrammingError) as ex: + msg = str(ex) + # MySQL 1061 "Duplicate key name" or PostgreSQL "already exists" -> already migrated + if "1061" in msg or "Duplicate key name" in msg or "already exists" in msg.lower(): + pass + else: + logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex) + except Exception as ex: + logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex) + + def migrate_db(): logging.disable(logging.ERROR) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) @@ -1383,3 +1435,5 @@ def migrate_db(): # Migrate system_settings.value from CharField to TextField for longer sandbox configs alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)")) logging.disable(logging.NOTSET) + # this is after re-enabling logging to allow logging changed user emails + migrate_add_unique_email(migrator) diff --git a/api/db/init_data.py b/api/db/init_data.py index 291f39fe5..525ae5bc5 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -21,6 +21,8 @@ import time import uuid from copy import deepcopy +from peewee import IntegrityError + from api.db import UserTenantRole from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM from api.db.services import UserService @@ -42,6 +44,10 @@ DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin") def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER): + if UserService.query(email=email): + logging.info("User with email %s already exists, skipping initialization.", email) + return + user_info = { "id": uuid.uuid1().hex, "password": encode_to_base64(password), @@ -70,8 +76,12 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ tenant_llm = get_init_tenant_llm(user_info["id"]) - if not UserService.save(**user_info): - logging.error("can't init admin.") + try: + if not UserService.save(**user_info): + logging.error("can't init admin.") + return + except IntegrityError: + logging.info("User with email %s already exists, skipping.", email) return TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) @@ -79,19 +89,17 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ logging.info( f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") - chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) - msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})) - if msg.find("ERROR: ") == 0: - logging.error( - "'{}' doesn't work. {}".format( - tenant["llm_id"], - msg)) - embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) - v, c = embd_mdl.encode(["Hello!"]) - if c == 0: - logging.error( - "'{}' doesn't work!".format( - tenant["embd_id"])) + if tenant["llm_id"]: + chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) + msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})) + if msg.find("ERROR: ") == 0: + logging.error("'{}' doesn't work. {}".format( tenant["llm_id"], msg)) + + if tenant["embd_id"]: + embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) + v, c = embd_mdl.encode(["Hello!"]) + if c == 0: + logging.error("'{}' doesn't work!".format(tenant["embd_id"])) def init_llm_factory():