refactor: implement tenant self queue for rag tasks

Commit 2c2b3092f6 (parent 4a797ab2d8)
Author: hj24
Date: 2025-10-28 14:20:43 +08:00
24 changed files with 3667 additions and 92 deletions

View File

@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from extensions.ext_database import db
-from extensions.ext_redis import redis_client
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.datasource_provider_service import DatasourceProviderService
-from services.feature_service import FeatureService
-from services.file_service import FileService
+from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
-from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
-from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

 logger = logging.getLogger(__name__)
@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
         )
         if rag_pipeline_invoke_entities:
             # store the rag_pipeline_invoke_entities to object storage
-            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
-            name = "rag_pipeline_invoke_entities.json"
-            # Convert list to proper JSON string
-            json_text = json.dumps(text)
-            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
-            features = FeatureService.get_features(dataset.tenant_id)
-            if features.billing.enabled and features.billing.subscription.plan == "sandbox":
-                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
-                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
-                if redis_client.get(tenant_pipeline_task_key):
-                    # Add to waiting queue using List operations (lpush)
-                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
-                else:
-                    # Set flag and execute task
-                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
-                    rag_pipeline_run_task.delay(  # type: ignore
-                        rag_pipeline_invoke_entities_file_id=upload_file.id,
-                        tenant_id=dataset.tenant_id,
-                    )
-            else:
-                priority_rag_pipeline_run_task.delay(  # type: ignore
-                    rag_pipeline_invoke_entities_file_id=upload_file.id,
-                    tenant_id=dataset.tenant_id,
-                )
+            RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
         # return batch, dataset, documents
         return {
             "batch": batch,

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass


@dataclass
class DocumentTask:
    """Document task entity for document indexing operations.
    This class represents a document indexing task that can be queued
    and processed by the document indexing system.
    """

    tenant_id: str
    dataset_id: str
    document_ids: list[str]
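Since DocumentTask is a plain dataclass, it can travel through the TenantSelfTaskQueue introduced below once converted to a JSON-serializable dict. A minimal usage sketch; the "document_indexing" unique key, the import paths, and the asdict round-trip are assumptions, not shown in this diff:

from dataclasses import asdict

# Import paths are assumptions; the diff does not show where these new files live.
from core.entities.document_task import DocumentTask
from libs.tenant_self_task_queue import TenantSelfTaskQueue

queue: TenantSelfTaskQueue[dict] = TenantSelfTaskQueue("tenant-123", "document_indexing")
task = DocumentTask(tenant_id="tenant-123", dataset_id="ds-1", document_ids=["doc-1", "doc-2"])

# Non-str payloads are JSON-wrapped by push_tasks, so enqueue the dict form.
queue.push_tasks([asdict(task)])

# A worker later drains the queue and rebuilds the dataclass.
payload = queue.get_next_task()
if payload is not None:
    task = DocumentTask(**payload)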

View File

@@ -0,0 +1,92 @@
import json
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

from extensions.ext_redis import redis_client

T = TypeVar("T")

TASK_WRAPPER_PREFIX = "__WRAPPER__:"


@dataclass
class TaskWrapper:
    data: Any

    def serialize(self) -> str:
        return json.dumps(self.data, ensure_ascii=False)

    @classmethod
    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
        data = json.loads(serialized_data)
        return cls(data)


class TenantSelfTaskQueue(Generic[T]):
    """
    Simple queue for tenant self tasks, used for tenant-level task isolation.
    It uses a Redis list to store tasks and a Redis key as the task-waiting flag.
    Supports any task that can be serialized with json.
    """

    DEFAULT_TASK_TTL = 60 * 60

    def __init__(self, tenant_id: str, unique_key: str):
        self.tenant_id = tenant_id
        self.unique_key = unique_key
        self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
        self.task_key = f"tenant_{unique_key}_task:{tenant_id}"

    def get_task_key(self):
        return redis_client.get(self.task_key)

    def set_task_waiting_time(self, ttl: int | None = None):
        ttl = ttl or self.DEFAULT_TASK_TTL
        redis_client.setex(self.task_key, ttl, 1)

    def delete_task_key(self):
        redis_client.delete(self.task_key)

    def push_tasks(self, tasks: list[T]):
        serialized_tasks = []
        for task in tasks:
            # Store plain strings directly, keeping full compatibility with pipeline scenarios
            if isinstance(task, str):
                serialized_tasks.append(task)
            else:
                # Use TaskWrapper for JSON serialization; prefix marks wrapped payloads
                wrapper = TaskWrapper(task)
                serialized_data = wrapper.serialize()
                serialized_tasks.append(f"{TASK_WRAPPER_PREFIX}{serialized_data}")
        redis_client.lpush(self.queue, *serialized_tasks)

    def pull_tasks(self, count: int = 1) -> list[T]:
        if count <= 0:
            return []
        tasks = []
        for _ in range(count):
            serialized_task = redis_client.rpop(self.queue)
            if not serialized_task:
                break
            if isinstance(serialized_task, bytes):
                serialized_task = serialized_task.decode("utf-8")
            # Check whether the payload was wrapped by TaskWrapper
            if serialized_task.startswith(TASK_WRAPPER_PREFIX):
                try:
                    wrapper_data = serialized_task[len(TASK_WRAPPER_PREFIX):]
                    wrapper = TaskWrapper.deserialize(wrapper_data)
                    tasks.append(wrapper.data)
                except (json.JSONDecodeError, TypeError, ValueError):
                    # Fall back to the raw string if deserialization fails
                    tasks.append(serialized_task)
            else:
                tasks.append(serialized_task)
        return tasks

    def get_next_task(self) -> T | None:
        tasks = self.pull_tasks(1)
        return tasks[0] if tasks else None
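
Together, the flag key and the list implement one-in-flight-per-tenant scheduling: the flag marks a running task, and the list holds the overflow. The worker side is not part of this excerpt; below is a minimal sketch of the expected hand-off when a task finishes, where the function name and the run_next callback are illustrative assumptions:

from collections.abc import Callable

# Assumed import path; the diff does not show where the new queue file lives.
from libs.tenant_self_task_queue import TenantSelfTaskQueue


# Sketch of the worker-side hand-off (not shown in this diff).
def on_task_finished(queue: TenantSelfTaskQueue[str], run_next: Callable[[str], None]) -> None:
    next_task = queue.get_next_task()
    if next_task is not None:
        # More work queued: refresh the waiting flag and chain the next task.
        queue.set_task_waiting_time()
        run_next(next_task)
    else:
        # Queue drained: drop the flag so the next submission dispatches immediately.
        queue.delete_task_key()

Note the mixed payload handling: queue.push_tasks(["file-id"]) stores the raw string, while queue.push_tasks([{"a": 1}]) is stored as __WRAPPER__:{"a": 1} and comes back from pull_tasks as a dict again.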