WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -1,5 +1,3 @@
from .workflow_execute_task import chatflow_execute_task
# Names re-exported as this package's public API. Every entry must match a name
# that is actually imported above, otherwise `from <package> import *` fails.
__all__ = ["chatflow_execute_task"]

View File

@ -8,18 +8,21 @@ from typing import Annotated, Any, TypeAlias, Union
from celery import shared_task
from flask import current_app, json
from pydantic import BaseModel, Discriminator, Field, Tag
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from libs.flask_utils import set_login_user
from models.account import Account
from models.model import App, AppMode, EndUser
from models.workflow import Workflow
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, Conversation, EndUser, Message
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@ -129,6 +132,11 @@ class _ChatflowRunner:
workflow = session.get(Workflow, exec_params.workflow_id)
app = session.get(App, workflow.app_id)
pause_config = PauseStateLayerConfig(
session_factory=self._session_factory,
state_owner_user_id=workflow.created_by,
)
user = self._resolve_user()
chat_generator = AdvancedChatAppGenerator()
@ -144,6 +152,7 @@ class _ChatflowRunner:
invoke_from=exec_params.invoke_from,
streaming=exec_params.streaming,
workflow_run_id=workflow_run_id,
pause_state_config=pause_config,
)
if not exec_params.streaming:
return response
@ -174,11 +183,135 @@ class _ChatflowRunner:
return user
# Resolve the user that created `workflow_run`, loading it through `session`.
# Yields an Account when the run's creator role is ACCOUNT, otherwise the EndUser
# looked up by the same `created_by` id; the annotation allows None for a missing row.
# NOTE(review): indentation is not visible in this view, so whether `return user`
# is nested under `if user:` or under the role check cannot be confirmed here.
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
# Convert the stored role value into a CreatorUserRole member
# (presumably raises on an unknown value — confirm against models.enums).
role = CreatorUserRole(workflow_run.created_by_role)
if role == CreatorUserRole.ACCOUNT:
user = session.get(Account, workflow_run.created_by)
if user:
# Pass the run's tenant id to the loaded account before handing it back.
user.set_tenant_id(workflow_run.tenant_id)
return user
# Fallback lookup: treat `created_by` as an EndUser id.
return session.get(EndUser, workflow_run.created_by)
@shared_task(queue="chatflow_execute")
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
    """Run a chatflow described by a JSON-encoded ``ChatflowExecutionParams`` payload.

    Args:
        payload: JSON string that is validated into ``ChatflowExecutionParams``.

    Returns:
        Whatever ``_ChatflowRunner.run()`` returns (a mapping or ``None`` per the
        annotation).
    """
    exec_params = ChatflowExecutionParams.model_validate_json(payload)

    # Report the parameters through the module logger only. The extra debug
    # ``print`` emitted the same data to stdout, bypassing log levels and handlers.
    logger.info("chatflow_execute_task run with params: %s", exec_params)

    runner = _ChatflowRunner(db.engine, exec_params=exec_params)
    return runner.run()
# Task registered as "resume_chatflow_execution" on the "chatflow_execute" queue.
# Resumes a paused chatflow run identified by payload["workflow_run_id"].
# Every missing prerequisite below is logged and makes the task return without raising;
# only a failure inside `generator.resume` is logged and re-raised.
# NOTE(review): indentation is not visible in this view, so the exact extent of the
# `with Session(...)` block cannot be confirmed here.
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
workflow_run_id = payload["workflow_run_id"]
# One session factory is shared by the run repository, the pause layer config and
# the execution repositories created further down.
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
# Without a stored pause entity there is nothing to resume.
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
if pause_entity is None:
logger.warning("No pause entity found for workflow run %s", workflow_run_id)
return
# The stored state is decoded to text and deserialized; a failure here is logged
# with traceback and the task ends quietly instead of raising.
try:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
except Exception:
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
return
# This task only handles advanced-chat generate entities; anything else is rejected.
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
logger.error(
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
workflow_run_id,
type(generate_entity),
)
return
# Rebuild the runtime state from the snapshot carried by the resumption context.
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
# Load every database row the generator needs: run, workflow, app, conversation,
# message and user. Each missing row aborts the resume with a warning.
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = session.get(WorkflowRun, workflow_run_id)
if workflow_run is None:
logger.warning("Workflow run %s not found during resume", workflow_run_id)
return
workflow = session.get(Workflow, workflow_run.workflow_id)
if workflow is None:
logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
return
app_model = session.get(App, workflow_run.app_id)
if app_model is None:
logger.warning("App %s not found during resume", workflow_run.app_id)
return
if generate_entity.conversation_id is None:
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
return
conversation = session.get(Conversation, generate_entity.conversation_id)
if conversation is None:
logger.warning(
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
)
return
# Messages linked to this run are ordered newest first; scalar() takes the first row.
message = session.scalar(
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
)
if message is None:
logger.warning("Message not found for workflow run %s", workflow_run_id)
return
user = _resolve_user_for_run(session, workflow_run)
if user is None:
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
return
# All prerequisites are present; hand the pause entity back to the repository
# (presumably this marks the pause as resumed — confirm in the repository).
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
# Pause state written during the resumed run is owned by the workflow's creator.
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
# A stored `triggered_from` value that is not a valid member falls back to APP_RUN.
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=app_model.id,
triggered_from=triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
generator = AdvancedChatAppGenerator()
# Unlike the earlier checks, a failure here is logged with traceback and re-raised.
try:
generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
conversation=conversation,
message=message,
application_generate_entity=generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_config,
)
except Exception:
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise

View File

@ -5,6 +5,7 @@ These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""
import logging
from datetime import UTC, datetime
from typing import Any
@ -14,23 +15,37 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
from core.app.layers.trigger_post_layer import TriggerPostLayer
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import (
TriggerData,
WorkflowResumeTaskData,
WorkflowTaskData,
)
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
logger = logging.getLogger(__name__)
# Translates the trigger type stored on a trigger log into the run source passed
# to the workflow execution repository. Trigger types without an entry are left
# to the lookup site, which supplies its own default.
_TRIGGER_TO_RUN_SOURCE: dict[AppTriggerType, WorkflowRunTriggeredFrom] = {
    AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK,
    AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE,
    AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN,
}
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
@ -144,6 +159,11 @@ def _execute_workflow_common(
if trigger_data.workflow_id:
args["workflow_id"] = str(trigger_data.workflow_id)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
# Execute the workflow with the trigger type
generator.generate(
app_model=app_model,
@ -159,6 +179,7 @@ def _execute_workflow_common(
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
],
pause_state_config=pause_config,
)
except Exception as e:
@ -176,6 +197,117 @@ def _execute_workflow_common(
session.commit()
@shared_task(name="resume_workflow_execution")
def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
    """Resume a paused workflow run via Celery.

    A missing trigger log, a missing pause state, or a pause state that does not
    hold a ``WorkflowAppGenerateEntity`` is logged and ends the task without
    raising.

    Args:
        task_data_dict: Raw task payload, validated into ``WorkflowResumeTaskData``
            (supplies ``workflow_trigger_log_id`` and ``workflow_run_id``).

    Raises:
        WorkflowNotFoundError: If the workflow or the app referenced by the
            trigger log cannot be loaded.
        Exception: Errors from deserializing the pause state or from
            ``WorkflowAppGenerator.resume`` are re-raised; in the latter case the
            trigger log is first marked as failed.
    """
    task_data = WorkflowResumeTaskData.model_validate(task_data_dict)

    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)

    # The session spans the whole task: the trigger log is updated and committed
    # both before the resumed run starts and again if it fails.
    with session_factory() as session:
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
        trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
        if not trigger_log:
            logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id)
            return

        pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
        if pause_entity is None:
            logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
            return

        try:
            resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
        except Exception:
            logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
            # Bare re-raise: propagates the active exception without rebinding it.
            raise

        generate_entity = resumption_context.get_generate_entity()
        if not isinstance(generate_entity, WorkflowAppGenerateEntity):
            logger.error(
                "Unsupported resumption entity for workflow run %s: %s",
                task_data.workflow_run_id,
                type(generate_entity),
            )
            return

        # Rebuild the runtime state from the snapshot carried by the context.
        graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)

        workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
        if workflow is None:
            raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")

        app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
        if app_model is None:
            raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")

        user = _get_user(session, trigger_log)

        cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
            queue=trigger_log.queue_name,
            schedule_strategy=AsyncWorkflowSystemStrategy,
            granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
        )
        cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)

        # A stored trigger type that is not a valid member is treated as UNKNOWN,
        # which has no mapping entry and therefore resolves to APP_RUN below.
        try:
            trigger_type = AppTriggerType(trigger_log.trigger_type)
        except ValueError:
            trigger_type = AppTriggerType.UNKNOWN

        triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN)

        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=generate_entity.app_config.app_id,
            triggered_from=triggered_from,
        )
        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Pause state written during the resumed run is owned by the workflow's creator.
        pause_config = PauseStateLayerConfig(
            session_factory=session_factory,
            state_owner_user_id=workflow.created_by,
        )

        workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)

        # Persist the RUNNING status before the generator is re-entered.
        trigger_log.status = WorkflowTriggerStatus.RUNNING
        trigger_log_repo.update(trigger_log)
        session.commit()

        generator = WorkflowAppGenerator()
        start_time = datetime.now(UTC)

        try:
            generator.resume(
                app_model=app_model,
                workflow=workflow,
                user=user,
                application_generate_entity=generate_entity,
                graph_runtime_state=graph_runtime_state,
                workflow_execution_repository=workflow_execution_repository,
                workflow_node_execution_repository=workflow_node_execution_repository,
                graph_engine_layers=[
                    TimeSliceLayer(cfs_plan_scheduler),
                    TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
                ],
                pause_state_config=pause_config,
            )
        except Exception as exc:
            # Record the failure on the trigger log, then let the error propagate.
            trigger_log.status = WorkflowTriggerStatus.FAILED
            trigger_log.error = str(exc)
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log_repo.update(trigger_log)
            session.commit()
            raise
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))