# Source path: dify/api/graphon/nodes/list_operator/node.py
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
# Generic element type shared by the predicate-combinator helpers below.
_T = TypeVar("_T")

# Array segment types this node can operate on: a tuple for runtime
# isinstance() checks, and the equivalent union for static annotations.
_SUPPORTED_TYPES_TUPLE = (ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment, ArrayBooleanSegment)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
    """Invert a predicate.

    The returned callable evaluates ``filter_`` on its argument and returns the
    logical negation of that result: ``True`` where ``filter_`` is falsy,
    ``False`` where it is truthy.
    """

    def negated(item: _T) -> bool:
        matched = filter_(item)
        return not matched

    return negated
class ListOperatorNode(Node[ListOperatorNodeData]):
    """Workflow node that filters, extracts, orders and limits an array variable.

    The array is looked up in the variable pool by ``node_data.variable``. Each
    enabled stage runs in a fixed order (filter, extract, order, limit) on the
    output of the previous stage. The node emits the final array together with
    its first and last elements.
    """

    node_type = BuiltinNodeTypes.LIST_OPERATOR

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self):
        """Run the enabled list operations and report the outcome.

        Returns:
            A ``NodeRunResult`` that is FAILED when the variable is missing, is not
            one of the supported array segment types, or a stage raises
            ``ListOperatorError``. Otherwise it is SUCCEEDED with ``result``,
            ``first_record`` and ``last_record`` outputs.
        """
        inputs: dict[str, Sequence[object]] = {}
        process_data: dict[str, Sequence[object]] = {}
        outputs: dict[str, Any] = {}

        variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
        if variable is None:
            error_message = f"Variable not found for selector: {self.node_data.variable}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )

        if not variable.value:
            # A falsy value short-circuits to an empty successful result. This check
            # runs before the type check below, so it applies to any falsy variable.
            inputs = {"variable": []}
            process_data = {"variable": []}
            if isinstance(variable, ArraySegment):
                # Preserve the concrete array segment type in the empty result.
                result = variable.model_copy(update={"value": []})
            else:
                result = ArrayAnySegment(value=[])
            outputs = {"result": result, "first_record": None, "last_record": None}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

        if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
            error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )

        if isinstance(variable, ArrayFileSegment):
            # File items are recorded in their dict form (File.to_dict()).
            inputs = {"variable": [item.to_dict() for item in variable.value]}
            process_data["variable"] = [item.to_dict() for item in variable.value]
        else:
            inputs = {"variable": variable.value}
            process_data["variable"] = variable.value

        try:
            # Filter
            if self.node_data.filter_by.enabled:
                variable = self._apply_filter(variable)
            # Extract
            if self.node_data.extract_by.enabled:
                variable = self._extract_slice(variable)
            # Order
            if self.node_data.order_by.enabled:
                variable = self._apply_order(variable)
            # Slice
            if self.node_data.limit.enabled:
                variable = self._apply_slice(variable)
            outputs = {
                "result": variable,
                "first_record": variable.value[0] if variable.value else None,
                "last_record": variable.value[-1] if variable.value else None,
            }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )
        except ListOperatorError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

    def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Narrow ``variable`` by every configured filter condition, in order.

        Each condition is applied to the array produced by the previous one.
        String condition values are rendered through the variable pool's
        ``convert_template`` before use.

        Raises:
            InvalidFilterValueError: if a condition value has the wrong type for the
                array kind, or a number filter value does not parse as a float.
                The node's own error types are used instead of bare ``ValueError``
                so that ``_run``'s ``except ListOperatorError`` handler reports a
                FAILED result. This assumes the ``.exc`` errors derive from
                ``ListOperatorError`` — confirm.
        """
        filter_func: Callable[[Any], bool]
        for condition in self.node_data.filter_by.conditions:
            if isinstance(variable, ArrayStringSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
            elif isinstance(variable, ArrayNumberSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                try:
                    number = float(value)
                except ValueError as e:
                    # A non-numeric rendered value is a user configuration error,
                    # so report it as a node failure rather than an uncaught error.
                    raise InvalidFilterValueError(f"Invalid filter value: {value}") from e
                filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=number)
            elif isinstance(variable, ArrayFileSegment):
                if isinstance(condition.value, str):
                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                elif isinstance(condition.value, bool):
                    raise InvalidFilterValueError(f"File filter expects a string value, got {type(condition.value)}")
                else:
                    # Non-string, non-bool values (e.g. sequences) are passed through unchanged.
                    value = condition.value
                filter_func = _get_file_filter_func(
                    key=condition.key,
                    condition=condition.comparison_operator,
                    value=value,
                )
            else:
                # The only remaining supported type is ArrayBooleanSegment.
                if not isinstance(condition.value, bool):
                    raise InvalidFilterValueError(
                        f"Boolean filter expects a boolean value, got {type(condition.value)}"
                    )
                filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
            result = list(filter(filter_func, variable.value))
            variable = variable.model_copy(update={"value": result})
        return variable

    def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Sort the array; descending when ``order_by.value`` is ``Order.DESC``.

        String, number and boolean arrays use their natural ordering. File arrays
        are sorted by the attribute named in ``order_by.key`` (see ``_order_file``).
        """
        if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
            result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
        else:
            result = _order_file(
                order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
            )
        return variable.model_copy(update={"value": result})

    def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Keep at most the first ``limit.size`` elements."""
        result = variable.value[: self.node_data.limit.size]
        return variable.model_copy(update={"value": result})

    def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Reduce the array to the single element at the configured 1-based serial.

        ``extract_by.serial`` is rendered through the variable pool's
        ``convert_template`` and parsed with ``int``.

        Raises:
            InvalidKeyError: if the serial is not an integer or lies outside
                ``1..len(variable.value)``. The lower bound previously raised a bare
                ``ValueError``, which bypassed ``_run``'s error handling.
        """
        serial_text = self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text
        try:
            value = int(serial_text)
        except ValueError as e:
            raise InvalidKeyError(f"Invalid serial index: must be an integer, got {serial_text!r}") from e
        if value < 1:
            raise InvalidKeyError(f"Invalid serial index: must be >= 1, got {value}")
        if value > len(variable.value):
            raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
        # Convert the 1-based serial to a 0-based list index.
        result = variable.value[value - 1]
        return variable.model_copy(update={"value": [result]})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return an accessor for a numeric ``File`` attribute.

    Only ``"size"`` is supported; the accessor returns ``File.size``.

    Raises:
        InvalidKeyError: for any other ``key``.
    """
    if key != "size":
        raise InvalidKeyError(f"Invalid key: {key}")

    def read_size(file: File) -> int:
        return file.size

    return read_size
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
    """Return an accessor that reads a string-valued attribute of a ``File``.

    Supported keys are ``name`` (``filename``), ``type``, ``extension``,
    ``mime_type``, ``transfer_method``, ``url`` (``remote_url``) and
    ``related_id``. Optional attributes read as ``""`` when falsy. ``type`` and
    ``transfer_method`` are passed through ``str()``.

    Raises:
        InvalidKeyError: for an unsupported ``key``.
    """
    if key == "name":
        return lambda file: file.filename or ""
    if key == "type":
        return lambda file: str(file.type)
    if key == "extension":
        return lambda file: file.extension or ""
    if key == "mime_type":
        return lambda file: file.mime_type or ""
    if key == "transfer_method":
        return lambda file: str(file.transfer_method)
    if key == "url":
        return lambda file: file.remote_url or ""
    if key == "related_id":
        return lambda file: file.related_id or ""
    raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
    """Build a string predicate for ``condition`` against ``value``.

    Supported operators are ``contains``, ``start with``, ``end with``, ``is``,
    ``in`` and ``empty``, plus the negated forms ``not contains``, ``is not``,
    ``not in`` and ``not empty``. Because ``value`` is a ``str``, ``in`` tests
    whether the argument is a substring of ``value``.

    Raises:
        InvalidConditionError: for any other ``condition``.
    """
    # Emptiness checks ignore ``value`` entirely.
    if condition == "empty":
        return lambda text: text == ""
    if condition == "not empty":
        return lambda text: text != ""
    if condition == "contains":
        return _contains(value)
    if condition == "not contains":
        return _negation(_contains(value))
    if condition == "start with":
        return _startswith(value)
    if condition == "end with":
        return _endswith(value)
    if condition == "is":
        return _is(value)
    if condition == "is not":
        return _negation(_is(value))
    if condition == "in":
        return _in(value)
    if condition == "not in":
        return _negation(_in(value))
    raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
    """Build a membership predicate: ``in`` or ``not in`` ``value``.

    Raises:
        InvalidConditionError: for any other ``condition``.
    """
    wants_membership = condition == "in"
    if not wants_membership and condition != "not in":
        raise InvalidConditionError(f"Invalid condition: {condition}")
    membership = _in(value)
    return membership if wants_membership else _negation(membership)
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
match condition:
case "=":
return _eq(value)
case "":
return _ne(value)
case "<":
return _lt(value)
case "":
return _le(value)
case ">":
return _gt(value)
case "":
return _ge(value)
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
    """Build an equality predicate over booleans.

    ``FilterOperator.IS`` matches elements equal to ``value``, and
    ``FilterOperator.IS_NOT`` matches the rest.

    Raises:
        InvalidConditionError: for any other operator.
    """
    if condition == FilterOperator.IS:
        return _is(value)
    if condition == FilterOperator.IS_NOT:
        return _negation(_is(value))
    raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
    """Build a predicate that tests one attribute of a ``File``.

    - ``name``, ``extension``, ``mime_type``, ``url``, ``related_id`` with a
      ``str`` value: string comparison via ``_get_string_filter_func``.
    - ``type``, ``transfer_method``: membership test via
      ``_get_sequence_filter_func``.
    - ``size`` with a ``str`` value: ``float(value)`` compared via
      ``_get_number_filter_func``.

    The attribute accessor and the comparison predicate are built once, when
    this function is called, rather than again for every file tested. Invalid
    conditions are therefore reported immediately, as they already are for the
    string and number filters.

    Raises:
        InvalidKeyError: for an unsupported key, or a key whose value is not a ``str``
            when one is required.
        InvalidConditionError: if ``condition`` is not valid for the key's kind.
        ValueError: if a ``size`` value does not parse as a float.
    """
    if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
        extract_text = _get_file_extract_string_func(key=key)
        text_matches = _get_string_filter_func(condition=condition, value=value)
        return lambda x: text_matches(extract_text(x))
    if key in {"type", "transfer_method"}:
        extract_text = _get_file_extract_string_func(key=key)
        member_matches = _get_sequence_filter_func(condition=condition, value=value)
        return lambda x: member_matches(extract_text(x))
    if key == "size" and isinstance(value, str):
        extract_number = _get_file_extract_number_func(key=key)
        number_matches = _get_number_filter_func(condition=condition, value=float(value))
        return lambda x: number_matches(extract_number(x))
    raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str) -> Callable[[str], bool]:
return lambda x: value in x
def _startswith(value: str) -> Callable[[str], bool]:
return lambda x: x.startswith(value)
def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: _T) -> Callable[[_T], bool]:
    """Predicate: the argument compares equal to ``value`` (``x == value``)."""

    def equals(candidate: _T) -> bool:
        return candidate == value

    return equals
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
return lambda x: x in value
def _eq(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x == value
def _ne(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x != value
def _lt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x < value
def _le(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x <= value
def _gt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x > value
def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
    """Return a new list of files sorted by the attribute named ``order_by``.

    ``"size"`` sorts numerically. ``name``, ``type``, ``extension``,
    ``mime_type``, ``transfer_method``, ``url`` and ``related_id`` sort by their
    string form (see ``_get_file_extract_string_func``). The order is descending
    when ``order`` is ``Order.DESC``.

    Raises:
        InvalidKeyError: for any other ``order_by``.
    """
    extract_func: Callable[[File], Any]
    if order_by == "size":
        extract_func = _get_file_extract_number_func(key=order_by)
    elif order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
        extract_func = _get_file_extract_string_func(key=order_by)
    else:
        raise InvalidKeyError(f"Invalid order key: {order_by}")
    descending = order == Order.DESC
    return sorted(array, key=extract_func, reverse=descending)