Mirror of https://github.com/langgenius/dify.git (synced 2026-01-19 11:45:05 +08:00)

feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)

 @shared_task(queue="dataset")
-def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
+def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
     """
     Clean document when document deleted.
     :param document_ids: document ids
@@ -30,6 +30,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
     start_at = time.perf_counter()

     try:
+        if not doc_form:
+            raise ValueError("doc_form is required")
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

         if not dataset:
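
For reference, a minimal dispatch sketch for the updated signature (not part of this diff; the ids and the doc_form value are placeholders):

    batch_clean_document_task.delay(
        document_ids=["document-id-1", "document-id-2"],
        dataset_id="dataset-id",
        doc_form="text_model",  # placeholder; a missing doc_form now fails fast with "doc_form is required"
        file_ids=["file-id-1"],
    )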
api/tasks/deal_dataset_index_update_task.py (Normal file, 171 lines)
@@ -0,0 +1,171 @@
import logging
import time

import click
from celery import shared_task  # type: ignore

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
    """
    Async deal dataset from index
    :param dataset_id: dataset_id
    :param action: action

    Usage: deal_dataset_index_update_task.delay(dataset_id, action)
    """
    logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()

        if not dataset:
            raise Exception("Dataset not found")
        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        if action == "upgrade":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .where(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )

            if dataset_documents:
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                for dataset_document in dataset_documents:
                    try:
                        # add from vector index
                        segments = (
                            db.session.query(DocumentSegment)
                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )

                                documents.append(document)
                            # save vector index
                            # clean keywords
                            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
                            index_processor.load(dataset, documents, with_keywords=False)
                            db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                                {"indexing_status": "completed"}, synchronize_session=False
                            )
                            db.session.commit()
                    except Exception as e:
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        elif action == "update":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .where(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )
            # add new index
            if dataset_documents:
                # update document status
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                # clean index
                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

                for dataset_document in dataset_documents:
                    # update from vector index
                    try:
                        segments = (
                            db.session.query(DocumentSegment)
                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )
                                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                                    child_chunks = segment.get_child_chunks()
                                    if child_chunks:
                                        child_documents = []
                                        for child_chunk in child_chunks:
                                            child_document = ChildDocument(
                                                page_content=child_chunk.content,
                                                metadata={
                                                    "doc_id": child_chunk.index_node_id,
                                                    "doc_hash": child_chunk.index_node_hash,
                                                    "document_id": segment.document_id,
                                                    "dataset_id": segment.dataset_id,
                                                },
                                            )
                                            child_documents.append(child_document)
                                        document.children = child_documents
                                documents.append(document)
                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)
                            db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                                {"indexing_status": "completed"}, synchronize_session=False
                            )
                            db.session.commit()
                    except Exception as e:
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        else:
            # clean collection
            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

        end_at = time.perf_counter()
        logging.info(
            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Deal dataset vector index failed")
    finally:
        db.session.close()
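
For reference, a dispatch sketch following the Usage line in the docstring above (not part of this diff; the dataset id is a placeholder, and "rebuild" stands in for any action other than "upgrade"/"update"):

    deal_dataset_index_update_task.delay("dataset-id", "upgrade")  # re-embeds segments per document, cleaning keywords first
    deal_dataset_index_update_task.delay("dataset-id", "update")   # cleans the dataset index, then reloads segments (child chunks included for parent-child datasets)
    deal_dataset_index_update_task.delay("dataset-id", "rebuild")  # any other action only cleans the collection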
@@ -47,6 +47,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        page_id = data_source_info["notion_page_id"]
        page_type = data_source_info["type"]
        page_edited_time = data_source_info["last_edited_time"]

        data_source_binding = (
            db.session.query(DataSourceOauthBinding)
            .where(
api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py (Normal file, 175 lines)
@@ -0,0 +1,175 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async Run rag pipeline
    :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
    rag_pipeline_invoke_entities include:
    :param pipeline_id: Pipeline ID
    :param user_id: User ID
    :param tenant_id: Tenant ID
    :param workflow_id: Workflow ID
    :param invoke_from: Invoke source (debugger, published, etc.)
    :param streaming: Whether to stream results
    :param datasource_type: Type of datasource
    :param datasource_info: Datasource information dict
    :param batch: Batch identifier
    :param document_id: Document ID (optional)
    :param start_node_id: Starting node ID
    :param inputs: Input parameters dict
    :param workflow_execution_id: Workflow execution ID
    :param workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    # run with threading, thread pool size is 10

    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine, expire_on_commit=False) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in priority pipeline task")
            raise
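
For reference, a sketch of the producer side (not part of this diff). The helper name persist_invoke_entities is hypothetical: whichever FileService method actually stores the JSON payload and returns its file id is not shown here, and the entity fields simply mirror RagPipelineInvokeEntity as it is read back above:

    import json

    entities = [
        {
            "user_id": "user-id",
            "tenant_id": "tenant-id",
            "pipeline_id": "pipeline-id",
            "workflow_id": "workflow-id",
            "streaming": False,
            "workflow_execution_id": None,
            "workflow_thread_pool_id": None,
            "application_generate_entity": {},  # serialized RagPipelineGenerateEntity (placeholder)
        }
    ]
    file_id = persist_invoke_entities(json.dumps(entities))  # hypothetical storage helper
    priority_rag_pipeline_run_task.delay(
        rag_pipeline_invoke_entities_file_id=file_id,
        tenant_id="tenant-id",
    )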
api/tasks/rag_pipeline/rag_pipeline_run_task.py (Normal file, 196 lines)
@@ -0,0 +1,196 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="pipeline")
def rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async Run rag pipeline
    :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
    rag_pipeline_invoke_entities include:
    :param pipeline_id: Pipeline ID
    :param user_id: User ID
    :param tenant_id: Tenant ID
    :param workflow_id: Workflow ID
    :param invoke_from: Invoke source (debugger, published, etc.)
    :param streaming: Whether to stream results
    :param datasource_type: Type of datasource
    :param datasource_info: Datasource information dict
    :param batch: Batch identifier
    :param document_id: Document ID (optional)
    :param start_node_id: Starting node ID
    :param inputs: Input parameters dict
    :param workflow_execution_id: Workflow execution ID
    :param workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    # run with threading, thread pool size is 10

    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
        tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"

        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)

        if next_file_id:
            # Process the next waiting task
            # Keep the flag set to indicate a task is running
            redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
            rag_pipeline_run_task.delay(  # type: ignore
                rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
                if isinstance(next_file_id, bytes)
                else next_file_id,
                tenant_id=tenant_id,
            )
        else:
            # No more waiting tasks, clear the flag
            redis_client.delete(tenant_pipeline_task_key)
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in pipeline task")
            raise
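
The finally block above implements a per-tenant hand-off: a Redis list holds waiting payload file ids and a one-hour flag marks that a run is in flight, with lpush on the producer side and rpop here giving FIFO order. A sketch of the matching enqueue side, assuming the same key names (the service that actually performs this step is not part of this diff):

    def enqueue_rag_pipeline_run(tenant_id: str, file_id: str) -> None:
        """Hypothetical producer matching the Redis keys used by rag_pipeline_run_task."""
        flag_key = f"tenant_pipeline_task:{tenant_id}"
        queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"
        if redis_client.get(flag_key):
            # A run is already in flight for this tenant: park the payload;
            # the running task's finally block will rpop it later.
            redis_client.lpush(queue_key, file_id)
        else:
            # Nothing running: mark the tenant busy (1 hour TTL) and dispatch now.
            redis_client.setex(flag_key, 60 * 60, 1)
            rag_pipeline_run_task.delay(
                rag_pipeline_invoke_entities_file_id=file_id,
                tenant_id=tenant_id,
            )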
@@ -354,6 +354,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    """
    Delete draft variables for an app in batches.

+    This function now handles cleanup of associated Offload data including:
+    - WorkflowDraftVariableFile records
+    - UploadFile records
+    - Object storage files
+
    Args:
        app_id: The ID of the app whose draft variables should be deleted
        batch_size: Number of records to delete per batch
@@ -365,22 +370,31 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
        raise ValueError("batch_size must be positive")

    total_deleted = 0
+    total_files_deleted = 0

    while True:
        with db.engine.begin() as conn:
-            # Get a batch of draft variable IDs
+            # Get a batch of draft variable IDs along with their file_ids
            query_sql = """
-                SELECT id FROM workflow_draft_variables
+                SELECT id, file_id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """
            result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})

-            draft_var_ids = [row[0] for row in result]
-            if not draft_var_ids:
+            rows = list(result)
+            if not rows:
                break

-            # Delete the batch
+            draft_var_ids = [row[0] for row in rows]
+            file_ids = [row[1] for row in rows if row[1] is not None]
+
+            # Clean up associated Offload data first
+            if file_ids:
+                files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
+                total_files_deleted += files_deleted
+
+            # Delete the draft variables
            delete_sql = """
                DELETE FROM workflow_draft_variables
                WHERE id IN :ids
@@ -391,11 +405,86 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:

            logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))

-    logger.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
+    logger.info(
+        click.style(
+            f"Deleted {total_deleted} total draft variables for app {app_id}. "
+            f"Cleaned up {total_files_deleted} total associated files.",
+            fg="green",
+        )
+    )
    return total_deleted


-def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str):
+def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
+    """
+    Delete Offload data associated with WorkflowDraftVariable file_ids.
+
+    This function:
+    1. Finds WorkflowDraftVariableFile records by file_ids
+    2. Deletes associated files from object storage
+    3. Deletes UploadFile records
+    4. Deletes WorkflowDraftVariableFile records
+
+    Args:
+        conn: Database connection
+        file_ids: List of WorkflowDraftVariableFile IDs
+
+    Returns:
+        Number of files cleaned up
+    """
+    from extensions.ext_storage import storage
+
+    if not file_ids:
+        return 0
+
+    files_deleted = 0
+
+    try:
+        # Get WorkflowDraftVariableFile records and their associated UploadFile keys
+        query_sql = """
+            SELECT wdvf.id, uf.key, uf.id as upload_file_id
+            FROM workflow_draft_variable_files wdvf
+            JOIN upload_files uf ON wdvf.upload_file_id = uf.id
+            WHERE wdvf.id IN :file_ids
+        """
+        result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
+        file_records = list(result)
+
+        # Delete from object storage and collect upload file IDs
+        upload_file_ids = []
+        for _, storage_key, upload_file_id in file_records:
+            try:
+                storage.delete(storage_key)
+                upload_file_ids.append(upload_file_id)
+                files_deleted += 1
+            except Exception:
+                logging.exception("Failed to delete storage object %s", storage_key)
+                # Continue with database cleanup even if storage deletion fails
+                upload_file_ids.append(upload_file_id)
+
+        # Delete UploadFile records
+        if upload_file_ids:
+            delete_upload_files_sql = """
+                DELETE FROM upload_files
+                WHERE id IN :upload_file_ids
+            """
+            conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
+
+        # Delete WorkflowDraftVariableFile records
+        delete_variable_files_sql = """
+            DELETE FROM workflow_draft_variable_files
+            WHERE id IN :file_ids
+        """
+        conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
+
+    except Exception:
+        logging.exception("Error deleting draft variable offload data:")
+        # Don't raise, as we want to continue with the main deletion process
+
+    return files_deleted
+
+
+def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
    while True:
        with db.engine.begin() as conn:
            rs = conn.execute(sa.text(query_sql), params)
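
For reference, an illustrative call (not part of this diff; the app id is a placeholder). The return value is the number of draft variables removed, and any WorkflowDraftVariableFile, UploadFile, and object-storage entries referenced through the file_id column are now cleaned up along the way:

    deleted = delete_draft_variables_batch("app-id", batch_size=500)
    logger.info("removed %d draft variables", deleted)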
@@ -10,20 +10,23 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
-def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
+def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
    """
    Async process document
    :param dataset_id:
    :param document_ids:
+    :param user_id:

-    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
+    Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
    """
    start_at = time.perf_counter()
    try:
@@ -31,11 +34,19 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
        if not dataset:
            logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            return
-        tenant_id = dataset.tenant_id
+        user = db.session.query(Account).where(Account.id == user_id).first()
+        if not user:
+            logger.info(click.style(f"User not found: {user_id}", fg="red"))
+            return
+        tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
+        if not tenant:
+            raise ValueError("Tenant not found")
+        user.current_tenant = tenant

        for document_id in document_ids:
            retry_indexing_cache_key = f"document_{document_id}_is_retried"
            # check document limit
-            features = FeatureService.get_features(tenant_id)
+            features = FeatureService.get_features(tenant.id)
            try:
                if features.billing.enabled:
                    vector_space = features.vector_space
@@ -87,8 +98,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                db.session.add(document)
                db.session.commit()

-                indexing_runner = IndexingRunner()
-                indexing_runner.run([document])
+                if dataset.runtime_mode == "rag_pipeline":
+                    rag_pipeline_service = RagPipelineService()
+                    rag_pipeline_service.retry_error_document(dataset, document, user)
+                else:
+                    indexing_runner = IndexingRunner()
+                    indexing_runner.run([document])
                redis_client.delete(retry_indexing_cache_key)
            except Exception as ex:
                document.indexing_status = "error"
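
For reference, a dispatch sketch matching the updated Usage line (ids are placeholders). Datasets whose runtime_mode is "rag_pipeline" retry through RagPipelineService; all others still go through IndexingRunner:

    retry_document_indexing_task.delay(
        "dataset-id",
        ["document-id-1", "document-id-2"],
        "user-id",  # newly required: used to load the Account and bind its current tenant
    )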
api/tasks/workflow_draft_var_tasks.py (Normal file, 27 lines)
@@ -0,0 +1,27 @@
"""
Celery tasks for asynchronous workflow execution storage operations.

These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""

import logging

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy.orm import Session

from extensions.ext_database import db

_logger = logging.getLogger(__name__)

from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService


@shared_task(queue="workflow_draft_var", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    deletions: list[DraftVarFileDeletion],
):
    with Session(bind=db.engine) as session, session.begin():
        srv = WorkflowDraftVariableService(session=session)
        srv.delete_workflow_draft_variable_file(deletions=deletions)
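
For reference, a dispatch sketch (not part of this diff). The shape of DraftVarFileDeletion is defined in services.workflow_draft_variable_service and not shown here, so pending_deletions is assumed to be a list produced by WorkflowDraftVariableService when offloaded draft-variable files need removal:

    # `pending_deletions` is assumed: list[DraftVarFileDeletion] built elsewhere.
    save_workflow_execution_task.delay(deletions=pending_deletions)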