mirror of
https://github.com/langgenius/dify.git
synced 2026-03-27 01:00:13 +08:00
346 lines
14 KiB
Python
346 lines
14 KiB
Python
from collections.abc import Callable, Sequence
|
|
from typing import Any, TypeAlias, TypeVar
|
|
|
|
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
|
from graphon.file import File
|
|
from graphon.node_events import NodeRunResult
|
|
from graphon.nodes.base.node import Node
|
|
from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
|
from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
|
|
|
from .entities import FilterOperator, ListOperatorNodeData, Order
|
|
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
|
|
|
# Concrete array segment types this node can process; used for the runtime
# isinstance check in ``ListOperatorNode._run``.
_SUPPORTED_TYPES_TUPLE = (ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment, ArrayBooleanSegment)

# The same set of segment types as a union, for use in type annotations.
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment

# Generic element type shared by the predicate helpers in this module.
_T = TypeVar("_T")
|
|
|
|
|
|
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
    """Invert a predicate.

    The returned callable evaluates ``filter_`` on its argument and returns the
    logical ``not`` of that result: ``True`` whenever ``filter_`` returns a falsy
    value, ``False`` otherwise.
    """
    return lambda value: not filter_(value)
|
|
|
|
|
|
class ListOperatorNode(Node[ListOperatorNodeData]):
    """Workflow node that filters, extracts, orders and limits an array variable.

    The array is read from the variable pool using ``node_data.variable`` and then
    passed through up to four optional stages, always in this order: filter,
    extract (keep a single element by 1-based serial), order, limit.

    Outputs are ``result`` (the transformed array segment) plus ``first_record``
    and ``last_record`` (its first/last element, or ``None`` when it is empty).
    """

    node_type = BuiltinNodeTypes.LIST_OPERATOR

    @classmethod
    def version(cls) -> str:
        """Return the version string of this node implementation."""
        return "1"

    def _run(self) -> NodeRunResult:
        """Run the configured pipeline and report the outcome as a ``NodeRunResult``.

        Returns a FAILED result in three cases:

        * the variable is missing;
        * the variable is not a supported array type;
        * any stage raises a ``ListOperatorError``.

        A variable whose value is falsy short-circuits to a SUCCEEDED result with
        an empty array. That check runs before the type check.
        """
        inputs: dict[str, Sequence[object]] = {}
        process_data: dict[str, Sequence[object]] = {}
        outputs: dict[str, Any] = {}

        variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
        if variable is None:
            error_message = f"Variable not found for selector: {self.node_data.variable}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )

        if not variable.value:
            # Nothing to process. Keep the concrete segment type when the variable
            # is an array; otherwise fall back to an empty ArrayAnySegment.
            inputs = {"variable": []}
            process_data = {"variable": []}
            if isinstance(variable, ArraySegment):
                result = variable.model_copy(update={"value": []})
            else:
                result = ArrayAnySegment(value=[])
            outputs = {"result": result, "first_record": None, "last_record": None}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

        if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
            error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )

        # File elements are recorded through ``to_dict()``; other element types as-is.
        if isinstance(variable, ArrayFileSegment):
            inputs = {"variable": [item.to_dict() for item in variable.value]}
            process_data["variable"] = [item.to_dict() for item in variable.value]
        else:
            inputs = {"variable": variable.value}
            process_data["variable"] = variable.value

        try:
            # Filter
            if self.node_data.filter_by.enabled:
                variable = self._apply_filter(variable)

            # Extract
            if self.node_data.extract_by.enabled:
                variable = self._extract_slice(variable)

            # Order
            if self.node_data.order_by.enabled:
                variable = self._apply_order(variable)

            # Slice
            if self.node_data.limit.enabled:
                variable = self._apply_slice(variable)

            outputs = {
                "result": variable,
                "first_record": variable.value[0] if variable.value else None,
                "last_record": variable.value[-1] if variable.value else None,
            }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )
        except ListOperatorError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

    def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Apply every configured filter condition in turn.

        Each condition filters the output of the previous one, so the conditions
        are effectively combined with AND.

        Operand handling depends on the element type:

        * string and number operands are rendered through the variable-pool
          template engine before comparison;
        * string operands for file filters are rendered the same way; non-string
          file operands are passed through unchanged.

        Raises:
            InvalidFilterValueError: a condition's value has the wrong type for the
                array's element type, or a number operand is not numeric.
            InvalidConditionError / InvalidKeyError: propagated from the predicate
                builders for unsupported operators or file keys.
        """
        filter_func: Callable[[Any], bool]
        for condition in self.node_data.filter_by.conditions:
            if isinstance(variable, ArrayStringSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
            elif isinstance(variable, ArrayNumberSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                try:
                    number = float(value)
                except ValueError as err:
                    # Use the node's own error type (like the checks above) instead
                    # of letting a bare ValueError escape _run's handler.
                    raise InvalidFilterValueError(f"Invalid filter value: {value}") from err
                filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=number)
            elif isinstance(variable, ArrayFileSegment):
                if isinstance(condition.value, bool):
                    raise InvalidFilterValueError(f"File filter expects a string value, got {type(condition.value)}")
                if isinstance(condition.value, str):
                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                else:
                    value = condition.value
                filter_func = _get_file_filter_func(
                    key=condition.key,
                    condition=condition.comparison_operator,
                    value=value,
                )
            else:
                # The only remaining supported type is ArrayBooleanSegment.
                if not isinstance(condition.value, bool):
                    raise InvalidFilterValueError(
                        f"Boolean filter expects a boolean value, got {type(condition.value)}"
                    )
                filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)

            variable = variable.model_copy(update={"value": list(filter(filter_func, variable.value))})
        return variable

    def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Sort the array, descending when ``order_by.value`` is ``Order.DESC``.

        String, number and boolean arrays are sorted by natural element order.
        File arrays are sorted by the file attribute named in ``order_by.key``.
        """
        if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
            result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
            variable = variable.model_copy(update={"value": result})
        else:
            result = _order_file(
                order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
            )
            variable = variable.model_copy(update={"value": result})

        return variable

    def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Keep the leading ``limit.size`` elements.

        This uses Python slice semantics, so a size larger than the array keeps
        every element. NOTE(review): assumes ``size`` is non-negative; confirm
        this is validated upstream.
        """
        result = variable.value[: self.node_data.limit.size]
        return variable.model_copy(update={"value": result})

    def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Keep only the element at the configured 1-based serial position.

        The serial is rendered through the variable-pool template engine and must
        be an integer in ``[1, len(variable.value)]``.

        Raises:
            InvalidKeyError: the serial is not an integer or is out of range. Every
                invalid serial raises this same type so ``_run`` reports all of
                them uniformly.
        """
        serial_text = self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text
        try:
            serial = int(serial_text)
        except ValueError as err:
            raise InvalidKeyError(f"Invalid serial index: must be an integer, got {serial_text!r}") from err
        if serial < 1:
            raise InvalidKeyError(f"Invalid serial index: must be >= 1, got {serial}")
        if serial > len(variable.value):
            raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {serial}")
        # Convert the user-facing 1-based serial to a 0-based index.
        return variable.model_copy(update={"value": [variable.value[serial - 1]]})
|
|
|
|
|
|
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return an accessor that reads a numeric attribute from a ``File``.

    Only ``"size"`` is supported; it reads ``File.size``.

    Raises:
        InvalidKeyError: ``key`` is anything other than ``"size"``.
    """
    if key != "size":
        raise InvalidKeyError(f"Invalid key: {key}")

    def read_size(file):
        return file.size

    return read_size
|
|
|
|
|
|
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
    """Return an accessor that reads a string-valued attribute from a ``File``.

    Supported keys:

    * ``name`` (reads ``filename``), ``extension``, ``mime_type``, ``url``
      (reads ``remote_url``) and ``related_id``: a falsy attribute becomes ``""``.
    * ``type`` and ``transfer_method``: the attribute is converted with ``str()``.

    Raises:
        InvalidKeyError: ``key`` is not one of the supported names.
    """
    accessors: dict[str, Callable[[File], str]] = {
        "name": lambda file: file.filename or "",
        "type": lambda file: str(file.type),
        "extension": lambda file: file.extension or "",
        "mime_type": lambda file: file.mime_type or "",
        "transfer_method": lambda file: str(file.transfer_method),
        "url": lambda file: file.remote_url or "",
        "related_id": lambda file: file.related_id or "",
    }
    accessor = accessors.get(key)
    if accessor is None:
        raise InvalidKeyError(f"Invalid key: {key}")
    return accessor
|
|
|
|
|
|
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
    """Return a string predicate that applies ``condition`` with ``value`` as the operand.

    Supported operators:

    * ``contains`` / ``not contains``
    * ``start with``, ``end with``
    * ``is`` / ``is not``
    * ``in`` / ``not in``: the element is a substring of ``value``
    * ``empty`` / ``not empty``: the element equals ``""`` or not

    Raises:
        InvalidConditionError: ``condition`` is not a supported operator.
    """
    if condition == "contains":
        return _contains(value)
    if condition == "not contains":
        return _negation(_contains(value))
    if condition == "start with":
        return _startswith(value)
    if condition == "end with":
        return _endswith(value)
    if condition == "is":
        return _is(value)
    if condition == "is not":
        return _negation(_is(value))
    if condition == "in":
        return _in(value)
    if condition == "not in":
        return _negation(_in(value))
    if condition == "empty":
        return lambda x: x == ""
    if condition == "not empty":
        return lambda x: x != ""
    raise InvalidConditionError(f"Invalid condition: {condition}")
|
|
|
|
|
|
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
    """Return a membership predicate against ``value``.

    ``in`` is true when the element is contained in ``value``; ``not in`` is its
    negation.

    Raises:
        InvalidConditionError: ``condition`` is neither ``in`` nor ``not in``.
    """
    if condition not in ("in", "not in"):
        raise InvalidConditionError(f"Invalid condition: {condition}")
    membership = _in(value)
    return membership if condition == "in" else _negation(membership)
|
|
|
|
|
|
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
    """Return a numeric comparison predicate with ``value`` as the right-hand operand.

    Supported operators are ``=``, ``≠``, ``<``, ``≤``, ``>`` and ``≥``.

    Raises:
        InvalidConditionError: ``condition`` is not a supported operator.
    """
    comparators = (
        ("=", _eq),
        ("≠", _ne),
        ("<", _lt),
        ("≤", _le),
        (">", _gt),
        ("≥", _ge),
    )
    for symbol, make_predicate in comparators:
        if condition == symbol:
            return make_predicate(value)
    raise InvalidConditionError(f"Invalid condition: {condition}")
|
|
|
|
|
|
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
    """Return a predicate comparing boolean elements with ``value``.

    ``FilterOperator.IS`` matches elements equal to ``value``; ``FilterOperator.IS_NOT``
    matches the rest.

    Raises:
        InvalidConditionError: ``condition`` is any other operator.
    """
    if condition == FilterOperator.IS:
        return _is(value)
    if condition == FilterOperator.IS_NOT:
        return _negation(_is(value))
    raise InvalidConditionError(f"Invalid condition: {condition}")
|
|
|
|
|
|
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
    """Build a predicate that tests one attribute of a ``File``.

    The attribute named by ``key`` is extracted from each file and compared with
    ``value`` using ``condition``:

    * ``name`` / ``extension`` / ``mime_type`` / ``url`` / ``related_id``: string
      operators; ``value`` must be a string.
    * ``type`` / ``transfer_method``: membership operators (``in`` / ``not in``).
    * ``size``: numeric operators; ``value`` must be a string parseable as a float.

    The comparison predicate is resolved once here rather than once per file, so
    an unsupported operator or non-numeric size operand is reported immediately.

    Raises:
        InvalidKeyError: ``key`` is unsupported, or ``value`` has the wrong type for it.
        InvalidFilterValueError: ``key`` is ``size`` and ``value`` is not numeric.
        InvalidConditionError: ``condition`` is not valid for ``key``.
    """
    if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
        extract_string = _get_file_extract_string_func(key=key)
        string_predicate = _get_string_filter_func(condition=condition, value=value)
        return lambda x: string_predicate(extract_string(x))

    if key in {"type", "transfer_method"}:
        extract_string = _get_file_extract_string_func(key=key)
        membership_predicate = _get_sequence_filter_func(condition=condition, value=value)
        return lambda x: membership_predicate(extract_string(x))

    if key == "size" and isinstance(value, str):
        try:
            threshold = float(value)
        except ValueError as err:
            # Report a bad operand with the module's filter-value error rather than
            # a bare ValueError raised from inside the per-file predicate.
            raise InvalidFilterValueError(f"Invalid filter value: {value}") from err
        extract_number = _get_file_extract_number_func(key=key)
        number_predicate = _get_number_filter_func(condition=condition, value=threshold)
        return lambda x: number_predicate(extract_number(x))

    raise InvalidKeyError(f"Invalid key: {key}")
|
|
|
|
|
|
def _contains(value: str) -> Callable[[str], bool]:
|
|
return lambda x: value in x
|
|
|
|
|
|
def _startswith(value: str) -> Callable[[str], bool]:
|
|
return lambda x: x.startswith(value)
|
|
|
|
|
|
def _endswith(value: str) -> Callable[[str], bool]:
|
|
return lambda x: x.endswith(value)
|
|
|
|
|
|
def _is(value: _T) -> Callable[[_T], bool]:
    """Return a predicate that is true when its argument compares equal (``==``) to ``value``."""

    def equals_value(candidate: _T) -> bool:
        return candidate == value

    return equals_value
|
|
|
|
|
|
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
|
|
return lambda x: x in value
|
|
|
|
|
|
def _eq(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x == value
|
|
|
|
|
|
def _ne(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x != value
|
|
|
|
|
|
def _lt(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x < value
|
|
|
|
|
|
def _le(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x <= value
|
|
|
|
|
|
def _gt(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x > value
|
|
|
|
|
|
def _ge(value: int | float) -> Callable[[int | float], bool]:
|
|
return lambda x: x >= value
|
|
|
|
|
|
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]) -> list[File]:
    """Return a new list of ``array`` sorted by the file attribute named ``order_by``.

    The string attributes (``name``, ``type``, ``extension``, ``mime_type``,
    ``transfer_method``, ``url``, ``related_id``) sort lexicographically by their
    extracted text; ``size`` sorts numerically. ``order == Order.DESC`` reverses
    the order. ``sorted`` is stable, so files with equal keys keep their relative
    order in either direction.

    Raises:
        InvalidKeyError: ``order_by`` is not a supported attribute, including the
            default empty string.
    """
    extract_func: Callable[[File], Any]
    if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
        extract_func = _get_file_extract_string_func(key=order_by)
    elif order_by == "size":
        extract_func = _get_file_extract_number_func(key=order_by)
    else:
        raise InvalidKeyError(f"Invalid order key: {order_by}")
    # The extractor already takes a single File, so it can be the sort key directly.
    return sorted(array, key=extract_func, reverse=order == Order.DESC)
|