mirror of
https://github.com/langgenius/dify.git
synced 2026-03-17 12:57:51 +08:00
WIP: resume
This commit is contained in:
@ -1,71 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, TypeAlias, final
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CommandParams:
    """Read-only parameters passed to a command source before a node runs.

    Frozen so command sources cannot mutate engine state through it.
    """

    # `next_node` is the instance of the next node to run.
    # (Fixed: the old comment called it `next_node_instance`, a field that does not exist.)
    next_node: BaseNode
|
||||
|
||||
|
||||
class _CommandTag(StrEnum):
|
||||
SUSPEND = "suspend"
|
||||
STOP = "stop"
|
||||
CONTINUE = "continue"
|
||||
|
||||
|
||||
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
    """Abstract base for engine commands; concrete subclasses pin `tag`."""

    # Instances are immutable once constructed.
    model_config = ConfigDict(frozen=True)

    # Discriminator value; each concrete subclass declares a fixed default.
    tag: _CommandTag

    @field_validator("tag")
    @classmethod
    def validate_value_type(cls, value):
        # Reject any `tag` that differs from the subclass's declared default,
        # so callers cannot forge the discriminator at construction time.
        if value != cls.model_fields["tag"].default:
            raise ValueError("Cannot modify 'tag'")
        return value
|
||||
|
||||
|
||||
@final
class StopCommand(_Command):
    """Command telling the graph engine to stop execution."""

    tag: _CommandTag = _CommandTag.STOP
|
||||
|
||||
|
||||
@final
class SuspendCommand(_Command):
    """Command telling the graph engine to suspend execution."""

    tag: _CommandTag = _CommandTag.SUSPEND
|
||||
|
||||
|
||||
@final
class ContinueCommand(_Command):
    """Command telling the graph engine to continue execution."""

    tag: _CommandTag = _CommandTag.CONTINUE
|
||||
|
||||
|
||||
def _get_command_tag(command: _Command):
    """Discriminator callback: return the `tag` used to select a command type."""
    tag = command.tag
    return tag
|
||||
|
||||
|
||||
# Discriminated union of all concrete commands; pydantic selects the concrete
# model by matching `_get_command_tag(value)` against each member's `Tag`.
CommandTypes: TypeAlias = Annotated[
    (
        Annotated[StopCommand, Tag(_CommandTag.STOP)]
        | Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
        | Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
    ),
    Discriminator(_get_command_tag),
]
|
||||
|
||||
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a command object to the engine, indicating whether the graph engine should
# suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]
|
||||
@ -1,8 +1,3 @@
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode", "HumanInputNodeData"]
|
||||
|
||||
@ -269,16 +269,6 @@ class HumanInputNodeData(BaseNodeData):
|
||||
return variable_mappings
|
||||
|
||||
|
||||
class HumanInputRequired(BaseModel):
    """Event data for human input required."""

    # Identifier of the form awaiting input.
    form_id: str
    # ID of the workflow node that raised the event.
    node_id: str
    # Content of the form to present — presumably rendered text/markup; verify against producer.
    form_content: str
    # Input fields the form collects.
    inputs: list[FormInput]
    # NOTE(review): looks like a token for web-app form access — TODO confirm semantics.
    web_app_form_token: Optional[str] = None
|
||||
|
||||
|
||||
class FormDefinition(BaseModel):
    """Definition of a human-input form: its content plus the inputs it collects."""

    form_content: str
    # Defaults to an empty list when no inputs are declared.
    inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
@ -18,11 +17,6 @@ _SELECTED_BRANCH_KEY = "selected_branch"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _FormSubmissionResult:
    """Carries the identifier of the action chosen in a form submission."""

    # Identifier of the selected action.
    action_id: str
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@ -5,7 +5,7 @@ import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
@ -13,6 +13,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
@ -59,7 +62,7 @@ class GraphExecutionProtocol(Protocol):
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
pause_reasons: Sequence[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
|
||||
Reference in New Issue
Block a user