chore: use from __future__ import annotations (#30254)

Co-authored-by: Dev <dev@Devs-MacBook-Pro-4.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Sara Rasool authored 2026-01-06 19:57:20 +05:00, committed by GitHub
parent 0294555893, commit 4f0fb6df2b
50 changed files with 253 additions and 163 deletions
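
The change is mechanical across all 50 files: each module gains from __future__ import annotations (PEP 563), so annotations are no longer evaluated at definition time but stored as strings, and forward references such as -> "DatasourcePlugin" or -> "ToolInvokeMeta" can be written unquoted. A minimal sketch of the before/after pattern, using a hypothetical Widget class rather than any dify class:

# Without the future import, a method that returns its own class has to quote
# the name, because the annotation is evaluated while the class body is still
# being built:
#
#     class Widget:
#         def clone(self) -> "Widget": ...
#
# With deferred evaluation the annotation is kept as a plain string and only
# resolved on demand, so the quotes are redundant:
from __future__ import annotations

import typing


class Widget:
    def clone(self) -> Widget:  # unquoted forward reference
        return Widget()


print(Widget.clone.__annotations__)         # {'return': 'Widget'}
print(typing.get_type_hints(Widget.clone))  # {'return': <class '__main__.Widget'>}

Runtime behaviour is unchanged for code that never introspects annotations; consumers that do (typing.get_type_hints, pydantic model construction) resolve the stored strings on demand.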

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from configs import dify_config
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from enum import StrEnum
from typing import Any
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
def value_of(cls, value: str) -> DatasourceProviderType:
"""
Get value of given mode.
@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
) -> "DatasourceParameter":
) -> DatasourceParameter:
"""
get a simple datasource parameter
@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
def empty(cls) -> DatasourceInvokeMeta:
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
def error_instance(cls, error: str) -> DatasourceInvokeMeta:
"""
Get an instance of DatasourceInvokeMeta with error
"""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
return cls(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
def value_of(cls, value: str) -> ProviderConfig.Type:
"""
Get value of given mode.

View File

@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@ -20,7 +22,7 @@ class ModelType(StrEnum):
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
@ -38,7 +40,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -76,7 +78,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
def of(cls, credential_type: str) -> CredentialType:
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib
import json
import logging
@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
Manages connection reuse across ClickzettaVector instances.
"""
_instance: Optional["ClickzettaConnectionPool"] = None
_instance: ClickzettaConnectionPool | None = None
_lock = threading.Lock()
def __init__(self):
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
self._start_cleanup_thread()
@classmethod
def get_instance(cls) -> "ClickzettaConnectionPool":
def get_instance(cls) -> ClickzettaConnectionPool:
"""Get singleton instance of connection pool."""
if cls._instance is None:
with cls._lock:
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
)
def _create_connection(self, config: ClickzettaConfig) -> "Connection":
def _create_connection(self, config: ClickzettaConfig) -> Connection:
"""Create a new ClickZetta connection."""
max_retries = 3
retry_delay = 1.0
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
def _configure_connection(self, connection: "Connection"):
def _configure_connection(self, connection: Connection):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
except Exception:
logger.exception("Failed to configure connection, continuing with defaults")
def _is_connection_valid(self, connection: "Connection") -> bool:
def _is_connection_valid(self, connection: Connection) -> bool:
"""Check if connection is still valid."""
try:
with connection.cursor() as cursor:
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
except Exception:
return False
def get_connection(self, config: ClickzettaConfig) -> "Connection":
def get_connection(self, config: ClickzettaConfig) -> Connection:
"""Get a connection from the pool or create a new one."""
config_key = self._get_config_key(config)
@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
def return_connection(self, config: ClickzettaConfig, connection: Connection):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
def _get_connection(self) -> "Connection":
def _get_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
def _return_connection(self, connection: "Connection"):
def _return_connection(self, connection: Connection):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
def __init__(self, vector_instance: "ClickzettaVector"):
def __init__(self, vector_instance: ClickzettaVector):
self.vector = vector_instance
self.connection: Connection | None = None
def __enter__(self) -> "Connection":
def __enter__(self) -> Connection:
self.connection = self.vector._get_connection()
return self.connection
@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
if self.connection:
self.vector._return_connection(self.connection)
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
"""Get a connection context manager."""
return self.ConnectionContext(self)
@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
def _ensure_connection(self) -> "Connection":
def _ensure_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._get_connection()
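
Besides the quoted-annotation cleanup, this file drops Optional from the typing import and rewrites Optional["ClickzettaConnectionPool"] as ClickzettaConnectionPool | None. Deferred evaluation also lets names imported only under TYPE_CHECKING (such as Connection here) appear unquoted in annotations, because the annotation is never executed at runtime. A small illustrative sketch, using sqlite3 as a stand-in for the ClickZetta driver:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to the type checker only; nothing is imported at runtime.
    from sqlite3 import Connection


def open_db(path: str) -> Connection:
    # Before the future import this had to be written -> "Connection"; with
    # deferred evaluation the unquoted name is fine even though Connection is
    # never bound in this module at runtime.
    import sqlite3  # the real import happens lazily, at call time

    return sqlite3.connect(path)


if __name__ == "__main__":
    conn = open_db(":memory:")
    conn.close()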

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
@ -22,7 +24,7 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
return cls(**config_dict)
def to_dict(self) -> dict[str, Any]:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any
@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
def deserialize(cls, serialized_data: str) -> TaskWrapper:
return cls.model_validate_json(serialized_data)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
class SchemaRegistry:
@ -11,7 +13,7 @@ class SchemaRegistry:
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_default_instance: ClassVar[SchemaRegistry | None] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self, base_dir: str):
@ -20,7 +22,7 @@ class SchemaRegistry:
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
@classmethod
def default_registry(cls) -> "SchemaRegistry":
def default_registry(cls) -> SchemaRegistry:
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
if cls._default_instance is None:
with cls._lock:
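
Same pattern here: ClassVar[Optional["SchemaRegistry"]] becomes ClassVar[SchemaRegistry | None], a class-level annotation that names the class being defined and uses the PEP 604 union spelling; both work unquoted only because the annotation stays a string. A compact sketch of the thread-safe lazy singleton shape visible in this hunk (simplified, with a placeholder Registry class and a made-up base_dir default):

from __future__ import annotations

import threading
from typing import Any, ClassVar


class Registry:
    # Self-referential ClassVar annotation, unquoted thanks to deferred evaluation.
    _default_instance: ClassVar[Registry | None] = None
    _lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self, base_dir: str):
        self.base_dir = base_dir
        self.metadata: dict[str, dict[str, Any]] = {}

    @classmethod
    def default_registry(cls) -> Registry:
        # Double-checked locking: cheap fast path, take the lock only on first use.
        if cls._default_instance is None:
            with cls._lock:
                if cls._default_instance is None:
                    cls._default_instance = cls(base_dir="schemas")
        return cls._default_instance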

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
@ -24,7 +26,7 @@ class Tool(ABC):
self.entity = entity
self.runtime = runtime
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
"""
fork a new tool with metadata
:return: the new tool
@ -166,7 +168,7 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
@ -24,7 +26,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
:return: the new tool

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pydantic import Field
from sqlalchemy import select
@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
credentials_schema = [
ProviderConfig(
name="auth_type",

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import contextlib
from collections.abc import Mapping
@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
def value_of(cls, value: str) -> ToolProviderType:
"""
Get value of given mode.
@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
def value_of(cls, value: str) -> ApiProviderSchemaType:
"""
Get value of given mode.
@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
def value_of(cls, value: str) -> ApiProviderAuthType:
"""
Get value of given mode.
@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
) -> "ToolParameter":
) -> ToolParameter:
"""
get a simple tool parameter
@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "ToolInvokeMeta":
def empty(cls) -> ToolInvokeMeta:
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "ToolInvokeMeta":
def error_instance(cls, error: str) -> ToolInvokeMeta:
"""
Get an instance of ToolInvokeMeta with error
"""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import json
import logging
@ -118,7 +120,7 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
@ -46,7 +48,7 @@ class PluginTool(Tool):
message_id=message_id,
)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping
from pydantic import Field
@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
self.provider_id = provider_id
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -181,7 +183,7 @@ class WorkflowTool(Tool):
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.file.models import File
@ -52,7 +54,7 @@ class SegmentType(StrEnum):
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
def infer_segment_type(cls, value: Any) -> SegmentType | None:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
@ -173,7 +175,7 @@ class SegmentType(StrEnum):
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType"):
def cast_value(value: Any, type_: SegmentType):
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
@ -193,7 +195,7 @@ class SegmentType(StrEnum):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
def exposed_type(self) -> SegmentType:
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@ -202,7 +204,7 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
def element_type(self) -> SegmentType | None:
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
@ -217,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: "SegmentType"):
def get_zero_value(t: SegmentType):
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -175,7 +177,7 @@ class Graph:
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@ -197,7 +199,7 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@ -284,9 +286,9 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
node_factory: NodeFactory,
root_node_id: str | None = None,
) -> "Graph":
) -> Graph:
"""
Initialize graph
@ -383,7 +385,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@ -398,7 +400,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@ -419,7 +421,7 @@ class GraphBuilder:
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
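
The GraphBuilder hunks show the fluent-builder pattern: every chaining method returns the builder itself and is annotated with the builder's own class, which previously forced quoting. A stripped-down sketch of the pattern (not the actual dify Graph/GraphBuilder, just the shape):

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Edge:
    tail: str
    head: str


@dataclass
class Builder:
    nodes: list[str] = field(default_factory=list)
    edges: list[Edge] = field(default_factory=list)

    def add_root(self, node: str) -> Builder:
        # Returning self lets calls chain; without deferred evaluation the
        # unquoted Builder annotation would raise NameError while the class
        # body is still executing.
        if self.nodes:
            raise ValueError("root already set")
        self.nodes.append(node)
        return self

    def add_node(self, node: str, *, from_node: str) -> Builder:
        self.nodes.append(node)
        self.edges.append(Edge(tail=from_node, head=node))
        return self


graph = Builder().add_root("start").add_node("llm", from_node="start").add_node("end", from_node="llm")
print([(e.tail, e.head) for e in graph.edges])  # [('start', 'llm'), ('llm', 'end')]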

View File

@ -5,6 +5,8 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
from __future__ import annotations
import contextvars
import logging
import queue
@ -232,7 +234,7 @@ class GraphEngine:
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)

View File

@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.

View File

@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: "PluginAgentStrategy",
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> "InvokeCredentials":
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib
import logging
import operator
@ -59,7 +61,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@ -241,7 +243,7 @@ class Node(Generic[NodeDataT]):
return
@property
def graph_init_params(self) -> "GraphInitParams":
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
@ -457,7 +459,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
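
The Node base class keeps a ClassVar registry mapping NodeType -> {version -> Node subclass}, populated from __init_subclass__; with deferred annotations both type[Node] in the registry annotation and the GraphInitParams / GraphRuntimeState parameters can drop their quotes. A minimal sketch of such a self-registering base class (hypothetical NodeType members, not the dify enum):

from __future__ import annotations

from enum import StrEnum
from typing import ClassVar


class NodeType(StrEnum):
    START = "start"
    LLM = "llm"


class Node:
    node_type: ClassVar[NodeType]
    # Self-referential registry annotation, unquoted under deferred evaluation.
    _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every concrete subclass registers itself under (node_type, version).
        node_type = getattr(cls, "node_type", None)
        if node_type is not None:
            Node._registry.setdefault(node_type, {})[cls.version()] = cls

    @classmethod
    def version(cls) -> str:
        return "1"


class StartNode(Node):
    node_type = NodeType.START


class LLMNode(Node):
    node_type = NodeType.LLM


assert Node._registry[NodeType.LLM]["1"] is LLMNode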

View File

@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import io
import json
@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
def _image_file_to_markdown(file: File, /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> "DraftVariableSaver":
) -> DraftVariableSaver:
pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
return data
@classmethod
def empty(cls) -> "SystemVariable":
def empty(cls) -> SystemVariable:
return cls()
def to_dict(self) -> dict[SystemVariableKey, Any]:
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
def as_view(self) -> "SystemVariableReadOnlyView":
def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)