mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 02:07:49 +08:00
Fix: init superuser can create duplicate users (#13221)
### What problem does this PR solve? This PR fixes two bugs related to RAGFlow's init-superuser functionality. #### Bug 1 When the RAGFlow server was started with the `--init-superuser` option, it would always create a new admin user even if one already existed, resulting in duplicate users. To fix this, I added a check before creating the superuser and added a *unique* constraint to the email column of the database to mitigate potential TOCTOU race conditions. Since existing databases could contain duplicate emails, I also added email de-duplication to the database migration. #### Bug 2 When the RAGFlow server was started with the `--init-superuser` option but without configured default LLM and embedding models, it would fail to start because the `init_superuser` function would always make test requests to the models even if they were not set. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -27,7 +27,23 @@ from functools import wraps
|
||||
|
||||
from quart_auth import AuthUser
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from peewee import (
|
||||
fn,
|
||||
InterfaceError,
|
||||
OperationalError,
|
||||
ProgrammingError,
|
||||
BigIntegerField,
|
||||
BooleanField,
|
||||
CharField,
|
||||
CompositeKey,
|
||||
DateTimeField,
|
||||
Field,
|
||||
FloatField,
|
||||
IntegerField,
|
||||
Metadata,
|
||||
Model,
|
||||
TextField,
|
||||
)
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
@ -692,7 +708,7 @@ class User(DataBaseModel, AuthUser):
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
password = CharField(max_length=255, null=True, help_text="password", index=True)
|
||||
email = CharField(max_length=255, null=False, help_text="email", index=True)
|
||||
email = CharField(max_length=255, null=False, help_text="email", unique=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", index=True)
|
||||
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright", index=True)
|
||||
@ -1332,6 +1348,42 @@ def alter_db_rename_column(migrator, table_name, old_column_name, new_column_nam
|
||||
# logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
|
||||
pass
|
||||
|
||||
def migrate_add_unique_email(migrator):
    """Deduplicate user emails, then add a UNIQUE index on ``user.email``.

    The function is safe to run repeatedly: once no duplicates remain step 1
    changes nothing, and step 2 treats an "index already exists" error as
    success.

    Args:
        migrator: The schema migrator used to issue the ``add_index``
            operation for the ``user`` table.

    Returns:
        None. Failures are logged at CRITICAL level rather than raised; if
        deduplication fails, the UNIQUE index is not attempted.
    """
    # Step 1: rename duplicate rows so the UNIQUE constraint can be applied.
    try:
        # The renamed address must still fit in the email column; fall back
        # to the declared length of 255 if the field reports none.
        max_len = User.email.max_length or 255

        # Read the duplicate list fully before issuing any UPDATE so the
        # rewrites below do not interleave with an open result set.
        duplicates = list(
            User.select(User.email)
            .group_by(User.email)
            .having(fn.COUNT(User.id) > 1)
            .tuples()
        )
        for (dup_email,) in duplicates:
            # Keep the superuser row, or the oldest row if there is no superuser.
            rows = list(
                User.select(User.id)
                .where(User.email == dup_email)
                .order_by(User.is_superuser.desc(), User.create_time.asc())
                .tuples()
            )
            for (uid,) in rows[1:]:
                suffix = f"_DUPLICATE_{uid[:8]}"
                # Trim the original address (not the suffix) when the result
                # would overflow the column; otherwise the database either
                # rejects the update or silently cuts off the suffix that
                # makes the value unique.
                new_email = f"{dup_email[:max_len - len(suffix)]}{suffix}"
                User.update(email=new_email).where(User.id == uid).execute()
                logging.warning("Renamed duplicate user %s email to %s during migration", uid, new_email)
    except Exception as ex:
        logging.critical("Failed to deduplicate user.email before adding UNIQUE constraint: %s", ex)
        return

    # Step 2: add the UNIQUE index via the migrator.
    try:
        migrate(migrator.add_index("user", ("email",), unique=True))
    except (OperationalError, ProgrammingError) as ex:
        msg = str(ex)
        # MySQL 1061 "Duplicate key name" or PostgreSQL "already exists"
        # means the index was created by an earlier run.
        already_migrated = "1061" in msg or "Duplicate key name" in msg or "already exists" in msg.lower()
        if not already_migrated:
            logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex)
    except Exception as ex:
        logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex)
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1383,3 +1435,5 @@ def migrate_db():
|
||||
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
|
||||
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
|
||||
logging.disable(logging.NOTSET)
|
||||
# this is after re-enabling logging to allow logging changed user emails
|
||||
migrate_add_unique_email(migrator)
|
||||
|
||||
@ -21,6 +21,8 @@ import time
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
|
||||
from peewee import IntegrityError
|
||||
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
||||
from api.db.services import UserService
|
||||
@ -42,6 +44,10 @@ DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io
|
||||
DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin")
|
||||
|
||||
def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER):
|
||||
if UserService.query(email=email):
|
||||
logging.info("User with email %s already exists, skipping initialization.", email)
|
||||
return
|
||||
|
||||
user_info = {
|
||||
"id": uuid.uuid1().hex,
|
||||
"password": encode_to_base64(password),
|
||||
@ -70,8 +76,12 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
logging.error("can't init admin.")
|
||||
try:
|
||||
if not UserService.save(**user_info):
|
||||
logging.error("can't init admin.")
|
||||
return
|
||||
except IntegrityError:
|
||||
logging.info("User with email %s already exists, skipping.", email)
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
@ -79,19 +89,17 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
logging.info(
|
||||
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
|
||||
if msg.find("ERROR: ") == 0:
|
||||
logging.error(
|
||||
"'{}' doesn't work. {}".format(
|
||||
tenant["llm_id"],
|
||||
msg))
|
||||
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
|
||||
v, c = embd_mdl.encode(["Hello!"])
|
||||
if c == 0:
|
||||
logging.error(
|
||||
"'{}' doesn't work!".format(
|
||||
tenant["embd_id"]))
|
||||
if tenant["llm_id"]:
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
|
||||
if msg.find("ERROR: ") == 0:
|
||||
logging.error("'{}' doesn't work. {}".format( tenant["llm_id"], msg))
|
||||
|
||||
if tenant["embd_id"]:
|
||||
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
|
||||
v, c = embd_mdl.encode(["Hello!"])
|
||||
if c == 0:
|
||||
logging.error("'{}' doesn't work!".format(tenant["embd_id"]))
|
||||
|
||||
|
||||
def init_llm_factory():
|
||||
|
||||
Reference in New Issue
Block a user