WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -1,71 +0,0 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
# `next_node_instance` is the instance of the next node to run.
next_node: BaseNode
class _CommandTag(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
model_config = ConfigDict(frozen=True)
tag: _CommandTag
@field_validator("tag")
@classmethod
def validate_value_type(cls, value):
if value != cls.model_fields["tag"].default:
raise ValueError("Cannot modify 'tag'")
return value
@final
class StopCommand(_Command):
tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(_Command):
tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(_Command):
tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: _Command):
return command.tag
CommandTypes: TypeAlias = Annotated[
(
Annotated[StopCommand, Tag(_CommandTag.STOP)]
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
),
Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]

View File

@ -1,8 +1,3 @@
"""
Human Input node implementation.
"""
from .entities import HumanInputNodeData
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode", "HumanInputNodeData"]

View File

@ -269,16 +269,6 @@ class HumanInputNodeData(BaseNodeData):
return variable_mappings
class HumanInputRequired(BaseModel):
"""Event data for human input required."""
form_id: str
node_id: str
form_content: str
inputs: list[FormInput]
web_app_form_token: Optional[str] = None
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)

View File

@ -1,4 +1,3 @@
import dataclasses
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any
@ -18,11 +17,6 @@ _SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class _FormSubmissionResult:
action_id: str
class HumanInputNode(Node[HumanInputNodeData]):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH

View File

@ -5,7 +5,7 @@ import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from pydantic.json import pydantic_encoder
@ -13,6 +13,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
@ -59,7 +62,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""