Files
ragflow/api/apps/__init__.py
jony376 46897d6fa4 Fix: bind memory message user_id to authenticated user for JWT auth (#14745)
### Related issues

Closes #14744

### What problem does this PR solve?

The Memory REST endpoint `POST /api/v1/messages` previously persisted
whatever `user_id` the client sent in the JSON body. Memory rows were
therefore attributed to an arbitrary string, even when the caller
authenticated as a normal workspace user via JWT (browser/session-style
bearer token decoded into an access token). That broke attribution and
audit semantics for shared memories (team visibility): any authorized
writer could spoof another subject id.

The Python SDK already sends an optional `user_id` for integrations
using **API keys** (`APIToken`) to tag an external subject distinct from
the tenant owner user.

### Solution

- Record **`g.auth_via_api_token`** in `_load_user`
(`api/apps/__init__.py`): set `True` only when authentication resolves
via `APIToken`, otherwise `False` after JWT-based login succeeds.
- In **`POST /messages`** (`memory_api.add_message`): if the request was
authenticated with an API key, keep accepting optional `user_id` from
the body (default empty string). For JWT-authenticated users, **always**
set stored `user_id` to **`current_user.id`** and ignore the client
field.
- Guard reads of `g` with **`RuntimeError`** handling so isolated
imports or tests without a Quart application context do not fail when
resolving `user_id`.
- Document on **`RAGFlow.add_message`** that `user_id` is only
meaningful for API-key authentication.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

### Testing

- `python -m py_compile` on modified modules (`api/apps/__init__.py`,
`api/apps/restful_apis/memory_api.py`).
- Recommended: run web/SDK memory message tests (`test_add_message`,
`test_message_routes_unit`) against a full environment with `quart` and
configured services.

### Notes for reviewers

- Behavior change **only** for callers using JWT-style authorization on
`POST /messages`; API-key callers keep prior optional `user_id`
semantics.

Co-authored-by: jony376 <jony376@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 13:26:05 +08:00

370 lines
13 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import sys
import time
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session, jsonify
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
from common.constants import StatusEnum, RetCode
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
from quart_auth import Unauthorized as QuartAuthUnauthorized
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
from quart_schema import QuartSchema
from common import settings
from api.utils.api_utils import server_error_response, get_json_result
from api.constants import API_VERSION
from common.misc_utils import get_uuid
settings.init_settings()
__all__ = ["app"]
UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
def _unauthorized_message(error):
if error is None:
return UNAUTHORIZED_MESSAGE
description = getattr(error, "description", None)
if description:
return description
try:
return repr(error)
except Exception:
return UNAUTHORIZED_MESSAGE
app = Quart(__name__)
app = cors(app, allow_origin="*")
# openapi supported
QuartSchema(app)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "redis"
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
app.config["MAX_CONTENT_LENGTH"] = int(
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
)
app.config['SECRET_KEY'] = settings.get_secret_key()
app.secret_key = settings.get_secret_key()
commands.register_commands(app)
from functools import wraps
from typing import ParamSpec, TypeVar
from collections.abc import Awaitable, Callable
from werkzeug.local import LocalProxy
T = TypeVar("T")
P = ParamSpec("P")
def _load_user_from_session():
"""Resolve the current user from the session cookie set by ``login_user()``.
OAuth/OIDC callbacks call ``login_user(user)`` which writes ``_user_id``
into the session. The frontend's response interceptor wipes the
Authorization header from localStorage on the first 401, so post-redirect
requests can arrive with no header at all — we still want to honour the
server-side session in that window.
The same access-token validity rules used by the JWT path are applied
here so that tokens revoked by ``logout`` (which rewrites the column to
``INVALID_<hex>``) or shortened by data corruption can't keep a stale
session authenticated.
"""
user_id = session.get("_user_id")
if not user_id:
return None
try:
users = UserService.query(id=user_id, status=StatusEnum.VALID.value)
except Exception:
logging.exception("load_user from session failed")
return None
if not users:
return None
user = users[0]
access_token = str(user.access_token or "").strip()
if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"):
return None
logging.debug("Authenticated request via session fallback for user_id=%s", user_id)
g.user = user
return user
def _load_user():
jwt = Serializer(secret_key=settings.get_secret_key())
authorization = request.headers.get("Authorization")
g.user = None
g.auth_via_api_token = False
if not authorization:
return _load_user_from_session()
# Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive)
if authorization.lower().startswith("bearer "):
parts = authorization.split(maxsplit=1)
if len(parts) < 2:
logging.warning("Authorization header has invalid bearer format")
return _load_user_from_session()
auth_token = parts[1]
else:
auth_token = authorization
# Try JWT decoding
try:
access_token = str(jwt.loads(auth_token))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return _load_user_from_session()
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return _load_user_from_session()
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return _load_user_from_session()
g.user = user[0]
return user[0]
return _load_user_from_session()
except Exception as e_jwt:
logging.warning(f"load_user from jwt got exception {e_jwt}")
# JWT decode failed, try as api_token
try:
objs = APIToken.query(token=auth_token)
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return _load_user_from_session()
g.auth_via_api_token = True
g.user = user[0]
return user[0]
logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken")
else:
logging.warning(f"load_user: No APIToken found for token={auth_token[:10]}...")
except Exception as e_api_token:
logging.warning(f"load_user from api token got exception {e_api_token}")
return _load_user_from_session()
current_user = LocalProxy(_load_user)
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
"""A decorator to restrict route access to authenticated users.
This should be used to wrap a route handler (or view function) to
enforce that only authenticated requests can access it. Note that
it is important that this decorator be wrapped by the route
decorator and not vice, versa, as below.
.. code-block:: python
@app.route('/')
@login_required
async def index():
...
If the request is not authenticated a
`quart.exceptions.Unauthorized` exception will be raised.
"""
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
user = current_user
if timing_enabled:
logging.info(
"api_timing login_required auth_ms=%.2f path=%s",
(time.perf_counter() - t_start) * 1000,
request.path,
)
if not user: # or not session.get("_user_id"):
raise QuartAuthUnauthorized()
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
def login_user(user, remember=False, duration=None, force=False, fresh=True):
"""
Logs a user in. You should pass the actual user object to this. If the
user's `is_active` property is ``False``, they will not be logged in
unless `force` is ``True``.
This will return ``True`` if the login attempt succeeds, and ``False`` if
it fails (i.e. because the user is inactive).
:param user: The user object to log in.
:type user: object
:param remember: Whether to remember the user after their session expires.
Defaults to ``False``.
:type remember: bool
:param duration: The amount of time before the remember cookie expires. If
``None`` the value set in the settings is used. Defaults to ``None``.
:type duration: :class:`datetime.timedelta`
:param force: If the user is inactive, setting this to ``True`` will log
them in regardless. Defaults to ``False``.
:type force: bool
:param fresh: setting this to ``False`` will log in the user with a session
marked as not "fresh". Defaults to ``True``.
:type fresh: bool
"""
if not force and not user.is_active:
return False
session["_user_id"] = user.id
session["_fresh"] = fresh
session["_id"] = get_uuid()
return True
def logout_user():
"""
Logs a user out. (You do not need to pass the actual user.) This will
also clean up the remember me cookie if it exists.
"""
if "_user_id" in session:
session.pop("_user_id")
if "_fresh" in session:
session.pop("_fresh")
if "_id" in session:
session.pop("_id")
COOKIE_NAME = "remember_token"
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
if cookie_name in request.cookies:
session["_remember"] = "clear"
if "_remember_seconds" in session:
session.pop("_remember_seconds")
return True
def search_pages_path(page_path):
app_path_list = [path for path in page_path.glob("*_app.py") if not path.name.startswith(".")]
api_path_list = [path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")]
app_path_list.extend(api_path_list)
restful_api_path_list = [path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")]
app_path_list.extend(restful_api_path_list)
return app_path_list
def register_page(page_path):
path = f"{page_path}"
page_name = page_path.stem.removesuffix("_app")
module_name = ".".join(page_path.parts[page_path.parts.index("api") : -1] + (page_name,))
spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec)
page.app = app
page.manager = Blueprint(page_name, module_name)
sys.modules[module_name] = page
spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name)
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix
pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / "api" / "apps",
Path(__file__).parent.parent / "api" / "apps" / "restful_apis",
Path(__file__).parent.parent / "api" / "apps" / "sdk",
]
client_urls_prefix = [register_page(path) for directory in pages_dir for path in search_pages_path(directory)]
# Register backward compatibility routes for deprecated APIs
from api.apps.backward_compat import register_backward_compat_routes
register_backward_compat_routes(app)
@app.errorhandler(404)
async def not_found(error):
logging.error(f"The requested URL {request.path} was not found")
message = f"Not Found: {request.path}"
response = {
"code": RetCode.NOT_FOUND,
"message": message,
"data": None,
"error": "Not Found",
}
return jsonify(response), RetCode.NOT_FOUND
@app.errorhandler(401)
async def unauthorized(error):
logging.warning("Unauthorized request")
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.errorhandler(QuartAuthUnauthorized)
async def unauthorized_quart_auth(error):
logging.warning("Unauthorized request (quart_auth)")
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
@app.errorhandler(WerkzeugUnauthorized)
async def unauthorized_werkzeug(error):
logging.warning("Unauthorized request (werkzeug)")
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
@app.teardown_request
def _db_close(exception):
if exception:
logging.exception(f"Request failed: {exception}")
close_connection()