Mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 17:38:04 +08:00)
refactor: implement tenant self queue for rag tasks
@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from extensions.ext_database import db
-from extensions.ext_redis import redis_client
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.datasource_provider_service import DatasourceProviderService
-from services.feature_service import FeatureService
-from services.file_service import FileService
+from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
-from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
-from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

 logger = logging.getLogger(__name__)

@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
         )

         if rag_pipeline_invoke_entities:
-            # store the rag_pipeline_invoke_entities to object storage
-            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
-            name = "rag_pipeline_invoke_entities.json"
-            # Convert list to proper JSON string
-            json_text = json.dumps(text)
-            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
-            features = FeatureService.get_features(dataset.tenant_id)
-            if features.billing.enabled and features.billing.subscription.plan == "sandbox":
-                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
-                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
-
-                if redis_client.get(tenant_pipeline_task_key):
-                    # Add to waiting queue using List operations (lpush)
-                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
-                else:
-                    # Set flag and execute task
-                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
-                    rag_pipeline_run_task.delay(  # type: ignore
-                        rag_pipeline_invoke_entities_file_id=upload_file.id,
-                        tenant_id=dataset.tenant_id,
-                    )
-
-            else:
-                priority_rag_pipeline_run_task.delay(  # type: ignore
-                    rag_pipeline_invoke_entities_file_id=upload_file.id,
-                    tenant_id=dataset.tenant_id,
-                )
-
+            RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
         # return batch, dataset, documents
         return {
             "batch": batch,
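The hunk above collapses the whole sandbox/priority branch into a single proxy call. The proxy's implementation (services/rag_pipeline/rag_pipeline_task_proxy.py) is not part of this excerpt, so the sketch below only illustrates how the removed inline logic could be re-expressed on top of the TenantSelfTaskQueue added in queue.py further down. The class shape, the "pipeline" unique_key, and the _upload_entities helper are assumptions, not the actual implementation — though note that unique_key="pipeline" would reproduce exactly the Redis keys the removed code built by hand.

from services.feature_service import FeatureService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

from core.rag.pipeline.queue import TenantSelfTaskQueue


class RagPipelineTaskProxySketch:
    """Illustrative only: one plausible shape for RagPipelineTaskProxy."""

    def __init__(self, tenant_id: str, user_id: str, rag_pipeline_invoke_entities: list):
        self.tenant_id = tenant_id
        self.user_id = user_id
        self.rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
        # Hypothetical unique_key; yields "tenant_self_pipeline_task_queue:<tenant_id>"
        # and "tenant_pipeline_task:<tenant_id>", matching the keys removed above
        self.queue: TenantSelfTaskQueue[str] = TenantSelfTaskQueue(tenant_id, "pipeline")

    def delay(self) -> None:
        # Hypothetical helper: persist the invoke entities to object storage,
        # as the removed inline code did with FileService.upload_text()
        file_id = self._upload_entities()

        features = FeatureService.get_features(self.tenant_id)
        if features.billing.enabled and features.billing.subscription.plan == "sandbox":
            if self.queue.get_task_key():
                # A task is already running for this tenant: park the file ID
                self.queue.push_tasks([file_id])
            else:
                # Mark the tenant busy (1-hour TTL), then dispatch to the shared queue
                self.queue.set_task_waiting_time()
                rag_pipeline_run_task.delay(
                    rag_pipeline_invoke_entities_file_id=file_id,
                    tenant_id=self.tenant_id,
                )
        else:
            # Paid plans keep the dedicated priority queue
            priority_rag_pipeline_run_task.delay(
                rag_pipeline_invoke_entities_file_id=file_id,
                tenant_id=self.tenant_id,
            )

    def _upload_entities(self) -> str:
        raise NotImplementedError  # placeholder; the real proxy handles storage itself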
api/core/entities/document_task.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class DocumentTask:
+    """Document task entity for document indexing operations.
+
+    This class represents a document indexing task that can be queued
+    and processed by the document indexing system.
+    """
+    tenant_id: str
+    dataset_id: str
+    document_ids: list[str]
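One caveat worth noting: TaskWrapper.serialize() in queue.py below calls json.dumps() on the task directly, and json.dumps() cannot encode a dataclass instance, so a DocumentTask would presumably be converted to a dict before being pushed onto the queue. A minimal round-trip sketch under that assumption:

from dataclasses import asdict

from core.entities.document_task import DocumentTask

task = DocumentTask(tenant_id="tenant-123", dataset_id="ds-1", document_ids=["doc-1", "doc-2"])

payload = asdict(task)              # plain dict: safe for json.dumps / TaskWrapper
restored = DocumentTask(**payload)  # rebuilt after json.loads on the consumer side
assert restored == task             # dataclass equality is field-wise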
api/core/rag/pipeline/__init__.py (new file, empty)
api/core/rag/pipeline/queue.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+import json
+from dataclasses import dataclass
+from typing import Any, Generic, TypeVar
+
+from extensions.ext_redis import redis_client
+
+T = TypeVar('T')
+
+
+TASK_WRAPPER_PREFIX = "__WRAPPER__:"
+
+
+@dataclass
+class TaskWrapper:
+    data: Any
+
+    def serialize(self) -> str:
+        return json.dumps(self.data, ensure_ascii=False)
+
+    @classmethod
+    def deserialize(cls, serialized_data: str) -> 'TaskWrapper':
+        data = json.loads(serialized_data)
+        return cls(data)
+
+
+class TenantSelfTaskQueue(Generic[T]):
+    """
+    Simple queue for tenant self tasks, used for per-tenant task isolation.
+    It uses a Redis list to store tasks and a Redis key as the task waiting flag.
+    Supports any task that can be serialized as JSON.
+    """
+    DEFAULT_TASK_TTL = 60 * 60
+
+    def __init__(self, tenant_id: str, unique_key: str):
+        self.tenant_id = tenant_id
+        self.unique_key = unique_key
+        self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
+        self.task_key = f"tenant_{unique_key}_task:{tenant_id}"
+
+    def get_task_key(self):
+        return redis_client.get(self.task_key)
+
+    def set_task_waiting_time(self, ttl: int | None = None):
+        ttl = ttl or self.DEFAULT_TASK_TTL
+        redis_client.setex(self.task_key, ttl, 1)
+
+    def delete_task_key(self):
+        redis_client.delete(self.task_key)
+
+    def push_tasks(self, tasks: list[T]):
+        serialized_tasks = []
+        for task in tasks:
+            # Store str tasks directly, maintaining full compatibility with the pipeline scenario
+            if isinstance(task, str):
+                serialized_tasks.append(task)
+            else:
+                # Use TaskWrapper for JSON serialization and add a prefix for identification
+                wrapper = TaskWrapper(task)
+                serialized_data = wrapper.serialize()
+                serialized_tasks.append(f"{TASK_WRAPPER_PREFIX}{serialized_data}")
+
+        redis_client.lpush(self.queue, *serialized_tasks)
+
+    def pull_tasks(self, count: int = 1) -> list[T]:
+        if count <= 0:
+            return []
+
+        tasks = []
+        for _ in range(count):
+            serialized_task = redis_client.rpop(self.queue)
+            if not serialized_task:
+                break
+
+            if isinstance(serialized_task, bytes):
+                serialized_task = serialized_task.decode('utf-8')
+
+            # Check whether the task was wrapped with TaskWrapper
+            if serialized_task.startswith(TASK_WRAPPER_PREFIX):
+                try:
+                    wrapper_data = serialized_task[len(TASK_WRAPPER_PREFIX):]
+                    wrapper = TaskWrapper.deserialize(wrapper_data)
+                    tasks.append(wrapper.data)
+                except (json.JSONDecodeError, TypeError, ValueError):
+                    tasks.append(serialized_task)
+            else:
+                tasks.append(serialized_task)
+
+        return tasks
+
+    def get_next_task(self) -> T | None:
+        tasks = self.pull_tasks(1)
+        return tasks[0] if tasks else None
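A usage sketch tying the pieces together (this assumes a reachable Redis instance behind redis_client; the "document" unique_key and the worker loop are illustrative, not taken from this commit):

from dataclasses import asdict

from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue

queue: TenantSelfTaskQueue = TenantSelfTaskQueue("tenant-123", "document")

# Strings take the fast path and are stored verbatim (pipeline file IDs rely on this).
queue.push_tasks(["upload-file-id-1"])

# Anything else is wrapped, JSON-encoded, and prefixed with "__WRAPPER__:".
task = DocumentTask(tenant_id="tenant-123", dataset_id="ds-1", document_ids=["doc-1"])
queue.push_tasks([asdict(task)])

# Consumer side: claim the tenant via the waiting flag, drain, then release,
# so the next producer dispatches a fresh worker instead of queueing forever.
if not queue.get_task_key():
    queue.set_task_waiting_time()  # busy flag, 1-hour TTL by default
    while (item := queue.get_next_task()) is not None:
        print("processing", item)  # str for the fast path, dict for wrapped tasks
    queue.delete_task_key()

Because push_tasks uses LPUSH (head) and pull_tasks uses RPOP (tail), the queue drains in FIFO order, matching the lpush-based waiting queue the refactor removes from pipeline_generator.py.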