Files
ragflow/api/apps/services/canvas_replica_service.py
Magicbook1108 1027916bfe Fix: inconsistent state handling for multi-user single-canvas access (#13267)
### What problem does this PR solve?

<img width="700" alt="image"
src="https://github.com/user-attachments/assets/1db7412e-4554-44bc-84ba-16421949aacc"
/>

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
2026-02-28 15:09:21 +08:00

259 lines
8.4 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import time
from api.db import CanvasCategory
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
class CanvasReplicaService:
    """
    Manage per-user canvas runtime replicas stored in Redis.

    Each replica is keyed by (canvas_id, tenant_id, runtime_user_id) and holds a
    JSON payload with the canvas DSL plus title/category metadata.

    Lifecycle:
    - bootstrap / create_if_absent: initialize a replica from DB DSL when absent
      (an existing replica is kept as-is)
    - load_for_run: read replica before run
    - replace_for_set: overwrite replica content for `/set` under a distributed lock
    - commit_after_run: persist run result back to replica (unlocked; last write wins)
    """

    # Replica expiry in seconds (3 hours); refreshed on every write.
    TTL_SECS = 3 * 60 * 60
    REPLICA_KEY_PREFIX = "canvas:replica"
    LOCK_KEY_PREFIX = "canvas:replica:lock"
    # Forwarded to RedisDistributedLock as `timeout` / `blocking_timeout`;
    # presumably lock auto-expiry and per-attempt wait — confirm in redis_conn.
    LOCK_TIMEOUT_SECS = 10
    LOCK_BLOCKING_TIMEOUT_SECS = 1
    # Number of acquire attempts, and base back-off between attempts (plus jitter).
    LOCK_RETRY_ATTEMPTS = 3
    LOCK_RETRY_SLEEP_SECS = 0.2

    @classmethod
    def normalize_dsl(cls, dsl):
        """
        Normalize DSL to a JSON-serializable dict.

        Args:
            dsl: A dict, or a JSON string encoding an object.

        Returns:
            dict: A copy produced by a JSON round-trip, containing only
            JSON-native types.

        Raises:
            ValueError: If the string is not valid JSON, the value is not a JSON
                object, or it cannot be serialized to JSON.
        """
        normalized = dsl
        if isinstance(normalized, str):
            try:
                normalized = json.loads(normalized)
            except Exception as e:
                raise ValueError("Invalid DSL JSON string.") from e
        if not isinstance(normalized, dict):
            raise ValueError("DSL must be a JSON object.")
        try:
            # Round-trip so the result holds only JSON-native types and shares
            # no mutable state with the caller's object.
            return json.loads(json.dumps(normalized, ensure_ascii=False))
        except Exception as e:
            raise ValueError("DSL is not JSON-serializable.") from e

    @classmethod
    def _replica_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
        """Build the Redis key holding the replica payload."""
        return f"{cls.REPLICA_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"

    @classmethod
    def _lock_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str:
        """Build the Redis key used for the replica's distributed lock."""
        return f"{cls.LOCK_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}"

    @classmethod
    def _read_payload(cls, replica_key: str):
        """
        Read replica payload from Redis.

        Returns:
            dict | None: The payload with its "dsl" normalized, or None when the
            key is missing or its content is not a valid payload.
        """
        cache_blob = REDIS_CONN.get(replica_key)
        if not cache_blob:
            return None
        try:
            payload = json.loads(cache_blob)
            if not isinstance(payload, dict):
                return None
            payload["dsl"] = cls.normalize_dsl(payload.get("dsl", {}))
            return payload
        except Exception as e:
            logging.warning("Failed to parse canvas replica %s: %s", replica_key, e)
            return None

    @classmethod
    def _write_payload(cls, replica_key: str, payload: dict) -> bool:
        """
        Stamp `updated_at`, write the payload and refresh its TTL.

        Mutates ``payload`` in place by setting ``updated_at``.

        Returns:
            bool: False if the Redis write reported failure, True otherwise.
        """
        payload["updated_at"] = int(time.time())
        # NOTE(review): assumes REDIS_CONN.set_obj signals failure by returning
        # False rather than raising; any other return value is treated as success,
        # matching the previous fire-and-forget behavior.
        return REDIS_CONN.set_obj(replica_key, payload, cls.TTL_SECS) is not False

    @classmethod
    def _build_payload(
        cls,
        canvas_id: str,
        tenant_id: str,
        runtime_user_id: str,
        dsl,
        canvas_category=CanvasCategory.Agent,
        title="",
    ):
        """
        Build a fresh replica payload dict.

        Falsy ``title`` becomes "" and falsy ``canvas_category`` becomes
        CanvasCategory.Agent.

        Raises:
            ValueError: If ``dsl`` fails normalization (see `normalize_dsl`).
        """
        return {
            "canvas_id": canvas_id,
            "tenant_id": str(tenant_id),
            "runtime_user_id": str(runtime_user_id),
            "title": title or "",
            "canvas_category": canvas_category or CanvasCategory.Agent,
            "dsl": cls.normalize_dsl(dsl),
            "updated_at": int(time.time()),
        }

    @classmethod
    def create_if_absent(
        cls,
        canvas_id: str,
        tenant_id: str,
        runtime_user_id: str,
        dsl,
        canvas_category=CanvasCategory.Agent,
        title="",
    ):
        """
        Create a runtime replica if it does not exist; otherwise keep existing state.

        Returns:
            dict: The existing replica payload, or the newly built one. The new
            payload is returned even if persisting it failed; the failure is logged.

        Raises:
            ValueError: If a new payload must be built and ``dsl`` is invalid.
        """
        replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
        existing = cls._read_payload(replica_key)
        if existing:
            return existing
        # NOTE(review): check-then-set is not atomic; a concurrent writer between
        # the read above and the write below would be overwritten.
        payload = cls._build_payload(canvas_id, str(tenant_id), str(runtime_user_id), dsl, canvas_category, title)
        if not cls._write_payload(replica_key, payload):
            # The caller can still proceed with the in-memory payload; a later
            # call will find no replica and attempt the write again.
            logging.warning("Failed to persist canvas replica: %s", replica_key)
        return payload

    @classmethod
    def bootstrap(
        cls,
        canvas_id: str,
        tenant_id: str,
        runtime_user_id: str,
        dsl,
        canvas_category=CanvasCategory.Agent,
        title="",
    ):
        """Bootstrap replica by creating it when absent and keeping existing runtime state."""
        return cls.create_if_absent(
            canvas_id=canvas_id,
            tenant_id=tenant_id,
            runtime_user_id=runtime_user_id,
            dsl=dsl,
            canvas_category=canvas_category,
            title=title,
        )

    @classmethod
    def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str):
        """Load current runtime replica used by /completion; None if missing or invalid."""
        replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
        return cls._read_payload(replica_key)

    @classmethod
    def replace_for_set(
        cls,
        canvas_id: str,
        tenant_id: str,
        runtime_user_id: str,
        dsl,
        canvas_category=CanvasCategory.Agent,
        title="",
    ):
        """
        Replace replica content for `/set` under lock.

        Returns:
            bool: True if the replica was written, False if the lock could not be
            acquired, the DSL was invalid, or the Redis write failed.
        """
        replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
        lock_key = cls._lock_key(canvas_id, str(tenant_id), str(runtime_user_id))
        lock = cls._acquire_lock_with_retry(lock_key)
        if not lock:
            logging.error("Failed to acquire canvas replica lock after retry: %s", lock_key)
            return False
        try:
            updated_payload = cls._build_payload(
                canvas_id=canvas_id,
                tenant_id=str(tenant_id),
                runtime_user_id=str(runtime_user_id),
                dsl=dsl,
                canvas_category=canvas_category,
                title=title,
            )
            if not cls._write_payload(replica_key, updated_payload):
                logging.error("Failed to write canvas replica from /set: %s", replica_key)
                return False
            return True
        except Exception:
            logging.exception("Failed to replace canvas replica from /set.")
            return False
        finally:
            try:
                lock.release()
            except Exception:
                logging.exception("Failed to release canvas replica lock: %s", lock_key)

    @classmethod
    def _acquire_lock_with_retry(cls, lock_key: str):
        """Acquire distributed lock with bounded retries; return lock object or None."""
        lock = RedisDistributedLock(
            lock_key,
            timeout=cls.LOCK_TIMEOUT_SECS,
            blocking_timeout=cls.LOCK_BLOCKING_TIMEOUT_SECS,
        )
        for idx in range(cls.LOCK_RETRY_ATTEMPTS):
            if lock.acquire():
                return lock
            # Back off with jitter between attempts, but not after the last one.
            if idx < cls.LOCK_RETRY_ATTEMPTS - 1:
                time.sleep(cls.LOCK_RETRY_SLEEP_SECS + random.uniform(0, 0.1))
        return None

    @classmethod
    def commit_after_run(
        cls,
        canvas_id: str,
        tenant_id: str,
        runtime_user_id: str,
        dsl,
        canvas_category=CanvasCategory.Agent,
        title="",
    ):
        """
        Commit post-run DSL into replica.

        The write is unconditional and does not take the replica lock, so the
        last writer wins. A truthy ``title``/``canvas_category`` from the caller
        is used; otherwise the value already stored in the replica (if any) is kept.

        Returns:
            bool: True on committed/saved, False on commit failure (including a
            Redis write that reports failure).

        Raises:
            ValueError: If ``dsl`` is invalid (validated before the commit attempt).
        """
        new_dsl = cls.normalize_dsl(dsl)
        replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
        try:
            latest_payload = cls._read_payload(replica_key) or {}
            updated_payload = cls._build_payload(
                canvas_id=canvas_id,
                tenant_id=str(tenant_id),
                runtime_user_id=str(runtime_user_id),
                dsl=new_dsl,
                canvas_category=canvas_category or latest_payload.get("canvas_category", CanvasCategory.Agent),
                title=title or latest_payload.get("title", ""),
            )
            if not cls._write_payload(replica_key, updated_payload):
                logging.error("Failed to write canvas runtime replica: %s", replica_key)
                return False
            return True
        except Exception:
            logging.exception("Failed to commit canvas runtime replica.")
            return False