mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -38,6 +38,7 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
@ -251,6 +252,116 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def create_workflow_pause(
    self,
    workflow_run_id: str,
    state_owner_user_id: str,
    state: str,
) -> WorkflowPauseEntity:
    """
    Persist a pause state for a workflow run.

    Stores the serialized execution state and marks the workflow run as
    paused so that it can be suspended now and resumed later.

    Args:
        workflow_run_id: Identifier of the workflow run to pause.
        state_owner_user_id: ID of the user who owns the stored pause state
            (used for file storage ownership).
        state: Serialized workflow execution state (a JSON string).

    Returns:
        A WorkflowPauseEntity describing the newly created pause state.

    Raises:
        ValueError: If workflow_run_id is invalid or the workflow run does
            not exist.
        RuntimeError: If the workflow is already paused or is otherwise in a
            state that cannot be paused.
    """
    # NOTE: `state_owner_user_id` could be dropped from the parameter list,
    # but doing so would require an extra query for the `Workflow` model
    # while creating the pause.
    ...
|
||||
|
||||
def resume_workflow_pause(
    self,
    workflow_run_id: str,
    pause_entity: WorkflowPauseEntity,
) -> WorkflowPauseEntity:
    """
    Resume a paused workflow run.

    Marks the pause as resumed by setting the `resumed_at` field of the
    WorkflowPauseEntity and returns the workflow run to running status.
    The returned `WorkflowPauseEntity` has `resumed_at` set.

    NOTE: this method does not delete the corresponding pause record or its
    associated stored state. It is the caller's responsibility to clear
    that state afterwards with `delete_workflow_pause`.

    Args:
        workflow_run_id: Identifier of the workflow run to resume.
        pause_entity: The pause entity to resume.

    Returns:
        The WorkflowPauseEntity representing the resumed pause state.

    Raises:
        ValueError: If workflow_run_id is invalid.
        RuntimeError: If the workflow is not paused or was already resumed.
    """
    ...
|
||||
|
||||
def delete_workflow_pause(
    self,
    pause_entity: WorkflowPauseEntity,
) -> None:
    """
    Permanently delete a workflow pause state.

    Removes the pause record for a workflow run together with its stored
    state file. Intended for cleanup once a paused workflow is no longer
    needed.

    Args:
        pause_entity: The pause entity to delete.

    Raises:
        ValueError: If pause_entity is invalid.
        RuntimeError: If the workflow is not paused.

    Note:
        This operation is irreversible: the stored workflow state is deleted
        along with the pause record.
    """
    ...
|
||||
|
||||
def prune_pauses(
|
||||
self,
|
||||
expiration: datetime,
|
||||
resumption_expiration: datetime,
|
||||
limit: int | None = None,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Clean up expired and old pause states.
|
||||
|
||||
Removes pause states that have expired (created before expiration time)
|
||||
and pause states that were resumed more than resumption_duration ago.
|
||||
This is used for maintenance and cleanup operations.
|
||||
|
||||
Args:
|
||||
expiration: Remove pause states created before this time
|
||||
resumption_expiration: Remove pause states resumed before this time
|
||||
limit: maximum number of records deleted in one call
|
||||
|
||||
Returns:
|
||||
a list of ids for pause records that were pruned
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
"""
|
||||
...
|
||||
|
||||
def get_daily_runs_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -20,19 +20,26 @@ Implementation Notes:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import and_, delete, func, null, or_, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.types import (
|
||||
@ -45,6 +52,10 @@ from repositories.types import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WorkflowRunError(RuntimeError):
    """
    Raised when a workflow run or its pause record is in a state that does
    not allow the requested pause operation (create / resume / delete).

    Derives from RuntimeError because the repository interface documents
    RuntimeError for these invalid-state conditions; callers that catch
    either RuntimeError or Exception will therefore handle it.
    """
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||
@ -301,6 +312,281 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||
return total_deleted
|
||||
|
||||
def create_workflow_pause(
    self,
    workflow_run_id: str,
    state_owner_user_id: str,
    state: str,
) -> WorkflowPauseEntity:
    """
    Create a new workflow pause state.

    Within a single database transaction this method:

    1. loads the workflow run and verifies it is in RUNNING status,
    2. removes any pause record that already exists for the run,
    3. uploads the serialized state to storage under a fresh object key,
    4. inserts a new pause record and switches the run to PAUSED.

    If anything fails after the state object was uploaded (including the
    commit itself), the uploaded object is deleted again on a best-effort
    basis so that no orphaned state file is left behind.

    Args:
        workflow_run_id: Identifier of the workflow run to pause
        state_owner_user_id: User ID who owns the pause state for file storage
            (currently not used by this implementation)
        state: Serialized workflow execution state (JSON string)

    Returns:
        WorkflowPauseEntity representing the created pause state

    Raises:
        ValueError: If the workflow run doesn't exist
        _WorkflowRunError: If the workflow run is not in RUNNING status
    """
    previous_pause_model_query = select(WorkflowPauseModel).where(
        WorkflowPauseModel.workflow_run_id == workflow_run_id
    )
    # Set once the state object has actually been written to storage, so the
    # error path knows whether there is anything to clean up.
    saved_state_key: str | None = None
    try:
        with self._session_maker() as session, session.begin():
            # Get the workflow run
            workflow_run = session.get(WorkflowRun, workflow_run_id)
            if workflow_run is None:
                raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

            # Check if workflow is in RUNNING status
            if workflow_run.status != WorkflowExecutionStatus.RUNNING:
                raise _WorkflowRunError(
                    f"Only WorkflowRun with RUNNING status can be paused, "
                    f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
                )

            # Replace any pause record left over for this run.
            previous_pause = session.scalars(previous_pause_model_query).first()
            if previous_pause:
                self._delete_pause_model(session, previous_pause)
                # we need to flush here to ensure that the old one is actually deleted.
                session.flush()

            # Upload the state file
            state_obj_key = f"workflow-state-{uuid.uuid4()}.json"
            storage.save(state_obj_key, state.encode())
            saved_state_key = state_obj_key

            # Create the pause record
            pause_model = WorkflowPauseModel()
            pause_model.id = str(uuidv7())
            pause_model.workflow_id = workflow_run.workflow_id
            pause_model.workflow_run_id = workflow_run.id
            pause_model.state_object_key = state_obj_key
            pause_model.created_at = naive_utc_now()

            # Update workflow run status
            workflow_run.status = WorkflowExecutionStatus.PAUSED

            # Save everything in a transaction
            session.add(pause_model)
            session.add(workflow_run)

            logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)

            return _PrivateWorkflowPauseEntity.from_models(pause_model)
    except Exception:
        # The transaction did not complete; do not leave the uploaded state
        # object orphaned in storage. Cleanup is best effort: a failure here
        # is logged and the original error is re-raised unchanged.
        if saved_state_key is not None:
            try:
                storage.delete(saved_state_key)
            except Exception:
                logger.exception(
                    "Failed to clean up state object after pause creation failure, object_key=%s",
                    saved_state_key,
                )
        raise
|
||||
|
||||
def get_workflow_pause(
    self,
    workflow_run_id: str,
) -> WorkflowPauseEntity | None:
    """
    Look up the pause state attached to a workflow run.

    Loads the workflow run together with its `pause` relationship and wraps
    the pause record, if any, in an entity.

    Args:
        workflow_run_id: Identifier of the workflow run to get pause state for

    Returns:
        A WorkflowPauseEntity if the run has a pause record, None otherwise

    Raises:
        ValueError: If no workflow run exists for workflow_run_id
    """
    # Eagerly load the pause relationship in the same round of queries.
    run_query = (
        select(WorkflowRun)
        .options(selectinload(WorkflowRun.pause))
        .where(WorkflowRun.id == workflow_run_id)
    )
    with self._session_maker() as session:
        run = session.scalar(run_query)
        if run is None:
            raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

        pause = run.pause
        return None if pause is None else _PrivateWorkflowPauseEntity.from_models(pause)
|
||||
|
||||
def resume_workflow_pause(
    self,
    workflow_run_id: str,
    pause_entity: WorkflowPauseEntity,
) -> WorkflowPauseEntity:
    """
    Resume a paused workflow.

    In one transaction: stamps `resumed_at` on the pause record, clears
    `pause_id` on the workflow run and switches the run back to RUNNING.
    The pause record and its stored state are not deleted here.

    Args:
        workflow_run_id: Identifier of the workflow run to resume
        pause_entity: The pause entity to resume; its id must match the
            pause record currently attached to the run

    Returns:
        WorkflowPauseEntity representing the resumed pause state

    Raises:
        ValueError: If no workflow run exists for workflow_run_id
        _WorkflowRunError: If the run is not PAUSED, has no pause record,
            the pause ids differ, or the pause was already resumed
    """
    run_query = (
        select(WorkflowRun)
        .options(selectinload(WorkflowRun.pause))
        .where(WorkflowRun.id == workflow_run_id)
    )
    with self._session_maker() as session, session.begin():
        # Get the workflow run with pause
        run = session.scalar(run_query)
        if run is None:
            raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

        # Guard clauses: every precondition is checked before any mutation.
        if run.status != WorkflowExecutionStatus.PAUSED:
            raise _WorkflowRunError(
                f"WorkflowRun is not in PAUSED status, workflow_run_id={workflow_run_id}, "
                f"current_status={run.status}"
            )

        pause = run.pause
        if pause is None:
            raise _WorkflowRunError(f"No pause state found for workflow run: {workflow_run_id}")

        if pause.id != pause_entity.id:
            raise _WorkflowRunError(
                "different id in WorkflowPause and WorkflowPauseEntity, "
                f"WorkflowPause.id={pause.id}, "
                f"WorkflowPauseEntity.id={pause_entity.id}"
            )

        if pause.resumed_at is not None:
            raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause.id}")

        # Mark as resumed
        pause.resumed_at = naive_utc_now()
        run.pause_id = None  # type: ignore
        run.status = WorkflowExecutionStatus.RUNNING

        session.add(pause)
        session.add(run)

        logger.info("Resumed workflow pause %s for workflow run %s", pause.id, workflow_run_id)

        return _PrivateWorkflowPauseEntity.from_models(pause)
|
||||
|
||||
def delete_workflow_pause(
    self,
    pause_entity: WorkflowPauseEntity,
) -> None:
    """
    Delete a workflow pause state.

    Looks up the pause record by the entity's id and removes it together
    with its stored state object (via `_delete_pause_model`), inside a
    single transaction.

    Args:
        pause_entity: The pause entity to delete

    Raises:
        _WorkflowRunError: If no pause record exists for pause_entity.id

    Note:
        This operation is irreversible. The stored workflow state is
        deleted along with the pause record.
    """
    with self._session_maker() as session, session.begin():
        # Get the pause model by ID
        record = session.get(WorkflowPauseModel, pause_entity.id)
        if record is None:
            raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")

        self._delete_pause_model(session, record)
|
||||
|
||||
@staticmethod
def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
    """
    Remove a pause record and its stored state object.

    The state object is deleted from storage first, then the record is
    marked for deletion in `session`. The caller owns the transaction:
    nothing is flushed or committed here.

    Args:
        session: Session in which the record is marked for deletion
        pause_model: The pause record to remove
    """
    object_key = pause_model.state_object_key
    storage.delete(object_key)

    # Delete the pause record
    session.delete(pause_model)

    logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)
|
||||
|
||||
def prune_pauses(
    self,
    expiration: datetime,
    resumption_expiration: datetime,
    limit: int | None = None,
) -> Sequence[str]:
    """
    Clean up expired and old pause states.

    A pause record is selected when it was created before `expiration`, or
    when it has been resumed and `resumed_at` is before
    `resumption_expiration`. Candidates are read in one session; each one is
    then handled in its own session and transaction: its state object is
    deleted from storage first and, only if that succeeds, the record
    itself is deleted. A record whose state object cannot be deleted is
    logged and left in place.

    Args:
        expiration: Remove pause states created before this time
        resumption_expiration: Remove pause states resumed before this time
        limit: maximum number of records deleted in one call; a falsy value
            (None or 0) falls back to 1000

    Returns:
        a list of ids for pause records that were pruned
    """
    batch_size: int = limit if limit else 1000
    pruned_record_ids: list[str] = []

    # Expired pauses (created before expiration time)
    created_too_long_ago = WorkflowPauseModel.created_at < expiration
    # Old resumed pauses (resumed before resumption_expiration)
    resumed_too_long_ago = and_(
        WorkflowPauseModel.resumed_at.is_not(null()),
        WorkflowPauseModel.resumed_at < resumption_expiration,
    )
    candidates_query = (
        select(WorkflowPauseModel)
        .where(or_(created_too_long_ago, resumed_too_long_ago))
        .limit(batch_size)
    )

    # First, collect the pause records to delete.
    with self._session_maker(expire_on_commit=False) as read_session:
        candidates = read_session.scalars(candidates_query).all()

    # Then remove each state file and its record, one transaction per record.
    # todo: this opens a separate session and transaction for each
    # WorkflowPauseModel record. consider batching.
    for record in candidates:
        with self._session_maker(expire_on_commit=False) as session, session.begin():
            try:
                storage.delete(record.state_object_key)
                logger.info(
                    "Deleted state object for pause, pause_id=%s, object_key=%s",
                    record.id,
                    record.state_object_key,
                )
            except Exception:
                # Keep the record so a later run can retry the cleanup.
                logger.exception(
                    "Failed to delete state file for pause, pause_id=%s, object_key=%s",
                    record.id,
                    record.state_object_key,
                )
            else:
                session.delete(record)
                pruned_record_ids.append(record.id)
                logger.info(
                    "workflow pause records deleted, id=%s, resumed_at=%s",
                    record.id,
                    record.resumed_at,
                )

    return pruned_record_ids
|
||||
|
||||
def get_daily_runs_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -510,3 +796,69 @@ GROUP BY
|
||||
)
|
||||
|
||||
return cast(list[AverageInteractionStats], response_data)
|
||||
|
||||
|
||||
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
    """
    Repository-internal WorkflowPauseEntity backed by a WorkflowPause row.

    Wraps a `WorkflowPauseModel` instance and exposes its identifiers and
    resume timestamp; the serialized state is loaded from storage on demand
    and cached on the entity.
    """

    def __init__(
        self,
        *,
        pause_model: WorkflowPauseModel,
    ) -> None:
        # Backing database record.
        self._pause_model = pause_model
        # State bytes loaded from storage; None until first loaded.
        self._cached_state: bytes | None = None

    @classmethod
    def from_models(cls, workflow_pause_model: WorkflowPauseModel) -> "_PrivateWorkflowPauseEntity":
        """
        Build an entity from a database model.

        Args:
            workflow_pause_model: The WorkflowPause database model

        Returns:
            _PrivateWorkflowPauseEntity: The constructed entity
        """
        return cls(pause_model=workflow_pause_model)

    @property
    def id(self) -> str:
        """Id of the underlying pause record."""
        return self._pause_model.id

    @property
    def workflow_execution_id(self) -> str:
        """Id of the workflow run this pause belongs to."""
        return self._pause_model.workflow_run_id

    def get_state(self) -> bytes:
        """
        Return the serialized workflow state.

        The state is read from storage using the record's
        `state_object_key` and kept on the entity, so later calls reuse the
        cached value instead of loading it again.

        Returns:
            bytes: The state as returned by `storage.load`
        """
        if self._cached_state is None:
            # Load the state from storage
            self._cached_state = storage.load(self._pause_model.state_object_key)
        return self._cached_state

    @property
    def resumed_at(self) -> datetime | None:
        """Time the pause was resumed, or None if it has not been resumed."""
        return self._pause_model.resumed_at
|
||||
|
||||
Reference in New Issue
Block a user