mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
refactor(typing): Fixup typing A2 - workflow engine & nodes (#31723)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file.file_manager import file_manager
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
@ -22,7 +23,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
http_request_http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol = file_manager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._http_request_http_client = http_request_http_client
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
@ -82,23 +82,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
node_id = node_config["id"]
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not is_str(node_type_str):
|
||||
raise ValueError(f"Node {node_id} missing or invalid type information")
|
||||
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_type_str}")
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
|
||||
@ -168,3 +168,18 @@ def _to_url(f: File, /):
|
||||
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
class FileManager:
|
||||
"""
|
||||
Adapter exposing file manager helpers behind FileManagerProtocol.
|
||||
|
||||
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
|
||||
where a protocol-typed file manager is expected.
|
||||
"""
|
||||
|
||||
def download(self, f: File, /) -> bytes:
|
||||
return download(f)
|
||||
|
||||
|
||||
file_manager = FileManager()
|
||||
|
||||
@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC):
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls) -> DefaultConfig:
|
||||
return {
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
],
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": {"result": {"type": "string", "children": None}},
|
||||
},
|
||||
variables: list[VariableConfig] = [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
]
|
||||
outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
|
||||
|
||||
config: CodeConfig = {
|
||||
"variables": variables,
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": outputs,
|
||||
}
|
||||
return {"type": "code", "config": config}
|
||||
|
||||
@ -230,3 +230,41 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any)
|
||||
|
||||
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
class SSRFProxy:
|
||||
"""
|
||||
Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol.
|
||||
|
||||
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
|
||||
where a protocol-typed HTTP client is expected.
|
||||
"""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return get(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return head(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return post(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return put(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return delete(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
|
||||
@ -35,6 +35,7 @@ class SchemaRegistry:
|
||||
registry.load_all_versions()
|
||||
|
||||
cls._default_instance = registry
|
||||
return cls._default_instance
|
||||
|
||||
return cls._default_instance
|
||||
|
||||
|
||||
@ -189,16 +189,13 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
@ -300,18 +297,15 @@ class ToolManager:
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
|
||||
24
api/core/workflow/entities/graph_config.py
Normal file
24
api/core/workflow/entities/graph_config.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
@ -5,15 +5,20 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
from libs.typing import is_str
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
@ -23,7 +28,7 @@ class NodeFactory(Protocol):
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
@ -63,28 +68,24 @@ class Graph:
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
node_configs_map: dict[str, NodeConfigDict] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
node_configs_map[node_config["id"]] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
node_configs_map: Mapping[str, NodeConfigDict],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
@ -113,10 +114,8 @@ class Graph:
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
@ -176,7 +175,7 @@ class Graph:
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_configs_map: dict[str, NodeConfigDict],
|
||||
node_factory: NodeFactory,
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
@ -303,7 +302,7 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -90,7 +89,7 @@ class GraphEngine:
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Execution Queues ===
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
self._ready_queue = self._graph_runtime_state.ready_queue
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
@ -15,10 +15,10 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
@ -75,7 +75,7 @@ class ResponseStreamCoordinator:
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, Self] | None = None
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_type = DatasourceProviderType.value_of(datasource_type)
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
@ -11,9 +11,9 @@ import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -79,8 +79,8 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -107,8 +107,8 @@ class Executor:
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
self._http_client = http_client
|
||||
self._file_manager = file_manager
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -336,7 +336,7 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
|
||||
"get": self._http_client.get,
|
||||
"head": self._http_client.head,
|
||||
"post": self._http_client.post,
|
||||
@ -348,7 +348,7 @@ class Executor:
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
request_args: dict[str, Any] = {
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
@ -361,14 +361,13 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](
|
||||
response = _METHOD_MAP[method_lc](
|
||||
url=self.url,
|
||||
**request_args,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
|
||||
@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
http_client: HttpClientProtocol = ssrf_proxy,
|
||||
http_client: HttpClientProtocol | None = None,
|
||||
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
file_manager: FileManagerProtocol = file_manager,
|
||||
file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._http_client = http_client
|
||||
self._http_client = http_client or ssrf_proxy
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
self._file_manager = file_manager
|
||||
self._file_manager = file_manager or default_file_manager
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
||||
@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return outputs
|
||||
|
||||
# Check if all non-None outputs are lists
|
||||
non_none_outputs = [output for output in outputs if output is not None]
|
||||
non_none_outputs: list[object] = [output for output in outputs if output is not None]
|
||||
if not non_none_outputs:
|
||||
return outputs
|
||||
|
||||
|
||||
@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
case "name":
|
||||
return lambda x: x.filename or ""
|
||||
case "type":
|
||||
return lambda x: x.type
|
||||
return lambda x: str(x.type)
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
return lambda x: str(x.transfer_method)
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case "related_id":
|
||||
@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
extract_number = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
@ -852,18 +852,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
# For issue #11247 - Check if prompt content is a string or a list
|
||||
prompt_content_type = type(prompt_content)
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_content)
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
if "#histories#" in content_item.data:
|
||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||
else:
|
||||
@ -873,13 +871,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Add current query to the prompt message
|
||||
if sys_query:
|
||||
if prompt_content_type == str:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
@ -1033,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if typed_node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
else:
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
|
||||
@property
|
||||
def request_error(self) -> type[Exception]: ...
|
||||
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
|
||||
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
|
||||
|
||||
|
||||
class FileManagerProtocol(Protocol):
|
||||
|
||||
@ -144,11 +144,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
Reference in New Issue
Block a user