WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -1,71 +0,0 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
# `next_node_instance` is the instance of the next node to run.
next_node: BaseNode
class _CommandTag(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
model_config = ConfigDict(frozen=True)
tag: _CommandTag
@field_validator("tag")
@classmethod
def validate_value_type(cls, value):
if value != cls.model_fields["tag"].default:
raise ValueError("Cannot modify 'tag'")
return value
@final
class StopCommand(_Command):
tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(_Command):
tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(_Command):
tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: _Command):
return command.tag
CommandTypes: TypeAlias = Annotated[
(
Annotated[StopCommand, Tag(_CommandTag.STOP)]
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
),
Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]