fix(tests): fix broken tests and linter issues

This commit is contained in:
QuantumGhost
2025-09-01 04:03:04 +08:00
parent 63c035d8a2
commit e2ae89e08d
25 changed files with 642 additions and 850 deletions

View File

@ -1,9 +1,10 @@
import dataclasses
import json
from collections.abc import Mapping
from typing import Any, TypeAlias
from enum import StrEnum
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
from configs import dify_config
from core.file.models import File
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
@ -16,12 +17,60 @@ from core.variables.segments import (
Segment,
StringSegment,
)
from core.variables.utils import dumps_with_segments
# NOTE(review): 10 * 1024 is 10 KiB; this line was previously annotated "100KB",
# which would be 100 * 1024 -- confirm which value is intended.
LARGE_VARIABLE_THRESHOLD = 10 * 1024 # 10 KiB in bytes
OBJECT_CHAR_LIMIT = 5000
ARRAY_CHAR_LIMIT = 1000
# Nesting depth past which calculate_json_size raises MaxDepthExceededError.
# NOTE(review): _MAX_DEPTH is assigned twice here; the later value (20) is the one in effect.
_MAX_DEPTH = 100
_MAX_DEPTH = 20
class _QAKeys:
    """Dictionary key names used by the ``_QAStructure`` payload."""

    # Top-level key holding the list of question/answer items.
    QA_CHUNKS = "qa_chunks"
    # Keys of each individual item in that list.
    QUESTION = "question"
    ANSWER = "answer"
class _PCKeys:
    """Dictionary key names used by the ``_ParentChildStructure`` payload."""

    # Top-level keys.
    PARENT_MODE = "parent_mode"
    PARENT_CHILD_CHUNKS = "parent_child_chunks"
    # Keys of each individual item in the chunk list.
    PARENT_CONTENT = "parent_content"
    CHILD_CONTENTS = "child_contents"
class _QAStructureItem(TypedDict):
    """One question/answer pair inside a QA-chunk payload."""

    question: str
    answer: str
class _QAStructure(TypedDict):
    """Mapping with a single ``qa_chunks`` list of question/answer items."""

    qa_chunks: list[_QAStructureItem]
class _ParentChildChunkItem(TypedDict):
    """One parent text together with the list of its child texts."""

    parent_content: str
    child_contents: list[str]
class _ParentChildStructure(TypedDict):
    """Mapping with a ``parent_mode`` string and a list of parent/child chunk items."""

    parent_mode: str
    parent_child_chunks: list[_ParentChildChunkItem]
class _SpecialChunkType(StrEnum):
parent_child = "parent_child"
qa = "qa"
_T = TypeVar("_T")


@dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]):
    """Immutable outcome of truncating one value.

    ``_T`` is the type of the value being carried.
    """

    # The value after the truncation step (the input itself when nothing was cut).
    value: _T
    # Size reported for ``value`` by the step that produced this result.
    value_size: int
    # Whether the step reported that content was cut.
    truncated: bool
class MaxDepthExceededError(Exception):
@ -51,13 +100,11 @@ class VariableTruncator:
Uses recursive size calculation to avoid repeated JSON serialization.
"""
_JSON_SEPARATORS = (",", ":")
def __init__(
self,
string_length_limit=5000,
array_element_limit: int = 20,
max_size_bytes: int = LARGE_VARIABLE_THRESHOLD,
max_size_bytes: int = 1024_000, # 1,024,000 bytes (~1000 KiB); NOTE(review): was annotated "100KB" -- confirm intended default
):
if string_length_limit <= 3:
raise ValueError("string_length_limit should be greater than 3.")
@ -86,25 +133,24 @@ class VariableTruncator:
of a WorkflowNodeExecution record. This ensures the mappings remain within the
specified size limits while preserving their structure.
"""
size = self.calculate_json_size(v)
if size < self._max_size_bytes:
return v, False
budget = self._max_size_bytes
is_truncated = False
truncated_mapping: dict[str, Any] = {}
size = len(v.items())
remaining = size
length = len(v.items())
used_size = 0
for key, value in v.items():
budget -= self.calculate_json_size(key)
if budget < 0:
break
truncated_value, value_truncated = self._truncate_value_to_budget(value, budget // remaining)
if value_truncated:
is_truncated = True
truncated_mapping[key] = truncated_value
# TODO(QuantumGhost): This approach is inefficient. Ideally, the truncation function should directly
# report the size of the truncated value.
budget -= self.calculate_json_size(truncated_value) + 2 # ":" and ","
used_size += self.calculate_json_size(key)
if used_size > budget:
truncated_mapping[key] = "..."
continue
value_budget = (budget - used_size) // (length - len(truncated_mapping))
if isinstance(value, Segment):
part_result = self._truncate_segment(value, value_budget)
else:
part_result = self._truncate_json_primitives(value, value_budget)
is_truncated = is_truncated or part_result.truncated
truncated_mapping[key] = part_result.value
used_size += part_result.value_size
return truncated_mapping, is_truncated
@staticmethod
@ -125,6 +171,27 @@ class VariableTruncator:
return True
def truncate(self, segment: Segment) -> TruncationResult:
    """Truncate ``segment`` and report whether anything was cut.

    String segments are truncated against ``self._string_length_limit``; every
    other segment is truncated against ``self._max_size_bytes``. When the
    type-specific pass still reports a size above ``self._max_size_bytes``, the
    result is converted to a ``StringSegment`` as a last resort.

    Args:
        segment: The segment to truncate.

    Returns:
        A ``TruncationResult`` holding the resulting segment and a ``truncated`` flag.
    """
    if isinstance(segment, StringSegment):
        part = self._truncate_segment(segment, self._string_length_limit)
    else:
        part = self._truncate_segment(segment, self._max_size_bytes)

    if part.value_size > self._max_size_bytes:
        # `_truncate_segment` is declared to return a Segment-wrapped value, so a
        # plain `isinstance(..., str)` test on `part.value` never matched for
        # string segments. Unwrap first, and keep accepting a bare str as well.
        oversized = part.value
        if isinstance(oversized, StringSegment):
            oversized = oversized.value
        if isinstance(oversized, str):
            shortened = self._truncate_string(oversized, self._max_size_bytes)
            return TruncationResult(result=StringSegment(value=shortened.value), truncated=True)

        # Final fallback: serialize to a JSON string and cut it by character count.
        # NOTE(review): the appended "..." can push the string 3 characters past
        # `_max_size_bytes`; kept as-is to preserve existing output.
        json_str = dumps_with_segments(part.value, ensure_ascii=False)
        if len(json_str) > self._max_size_bytes:
            json_str = json_str[: self._max_size_bytes] + "..."
        return TruncationResult(result=StringSegment(value=json_str), truncated=True)

    return TruncationResult(
        result=segment.model_copy(update={"value": part.value.value}),
        truncated=part.truncated,
    )
def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
"""
Apply smart truncation to a variable value.
@ -136,43 +203,38 @@ class VariableTruncator:
"""
if not VariableTruncator._segment_need_truncation(segment):
return TruncationResult(result=segment, truncated=False)
return _PartResult(segment, self.calculate_json_size(segment.value), False)
result: _PartResult[Any]
# Apply type-specific truncation with target size
if isinstance(segment, ArraySegment):
truncated_value, was_truncated = self._truncate_array(segment.value, self._max_size_bytes)
result = self._truncate_array(segment.value, target_size)
elif isinstance(segment, StringSegment):
truncated_value, was_truncated = self._truncate_string(segment.value)
result = self._truncate_string(segment.value, target_size)
elif isinstance(segment, ObjectSegment):
truncated_value, was_truncated = self._truncate_object(segment.value, self._max_size_bytes)
result = self._truncate_object(segment.value, target_size)
else:
raise AssertionError("this should be unreachable.")
# Check if we still exceed the final character limit after type-specific truncation
if not was_truncated:
return TruncationResult(result=segment, truncated=False)
truncated_size = self.calculate_json_size(truncated_value)
if truncated_size > self._max_size_bytes:
if isinstance(truncated_value, str):
return TruncationResult(StringSegment(value=truncated_value[: self._max_size_bytes - 3]), True)
# Apply final fallback - convert to JSON string and truncate
json_str = json.dumps(truncated_value, ensure_ascii=False, separators=self._JSON_SEPARATORS)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
return TruncationResult(result=segment.model_copy(update={"value": truncated_value}), truncated=True)
return _PartResult(
value=segment.model_copy(update={"value": result.value}),
value_size=result.value_size,
truncated=result.truncated,
)
@staticmethod
def calculate_json_size(value: Any, depth=0) -> int:
"""Recursively calculate JSON size without serialization."""
if isinstance(value, Segment):
return VariableTruncator.calculate_json_size(value.value)
if depth > _MAX_DEPTH:
raise MaxDepthExceededError()
if isinstance(value, str):
# For strings, we need to account for escaping and quotes
# Rough estimate: each character might need escaping, plus 2 for quotes
return len(value.encode("utf-8")) + 2
# Ideally, the size of strings should be calculated based on their utf-8 encoded length.
# However, this adds complexity as we would need to compute encoded sizes consistently
# throughout the code. Therefore, we approximate the size using the string's length.
# Rough estimate: number of characters, plus 2 for quotes
return len(value) + 2
elif isinstance(value, (int, float)):
return len(str(value))
elif isinstance(value, bool):
@ -197,60 +259,73 @@ class VariableTruncator:
total += 1 # ":"
total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
return total
elif isinstance(value, File):
return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
else:
raise UnknownTypeError(f"got unknown type {type(value)}")
def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
    """Shorten ``value`` so that its estimated JSON size fits ``target_size``.

    A string that already fits (estimated size at or below ``target_size``) is
    returned untouched. Otherwise it is cut and suffixed with ``"..."``; the kept
    prefix is capped by both ``self._string_length_limit`` and the room left
    after reserving 5 units (two quotes plus the three-character ellipsis).

    Args:
        value: The string to shorten.
        target_size: Size budget, in the units reported by ``calculate_json_size``.

    Returns:
        A ``_PartResult`` with the resulting string, its estimated size, and
        whether it was cut.
    """
    size = self.calculate_json_size(value)
    # `<=`: a string that exactly fills the budget fits; cutting it would lose
    # characters without making the result any smaller.
    if size <= target_size:
        return _PartResult(value, size, False)
    if target_size < 5:
        # Not even the quoted ellipsis fits; it is still the smallest marker emitted.
        return _PartResult("...", 5, True)
    keep = min(self._string_length_limit, target_size - 5)
    shortened = value[:keep] + "..."
    return _PartResult(shortened, self.calculate_json_size(shortened), True)
def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
    """Truncate a list by item count and by size budget.

    1. At most ``self._array_element_limit`` items are kept.
    2. Each kept item is truncated against the budget still unused
       (``target_size`` minus what earlier items consumed); once the budget is
       exhausted the remaining items are dropped.

    Args:
        value: The list to truncate.
        target_size: Size budget, in the units reported by ``calculate_json_size``.

    Returns:
        A ``_PartResult`` with the kept items, the accumulated size, and a flag
        that is True when any item was cut or dropped.
    """
    kept: list[Any] = []
    truncated = False
    used_size = self.calculate_json_size([])
    max_items = self._array_element_limit

    for i, item in enumerate(value):
        if i >= max_items:
            # Items beyond the element limit are dropped.
            return _PartResult(kept, used_size, True)

        separator = 1 if i > 0 else 0  # comma before every item but the first
        if used_size + separator > target_size:
            # Budget exhausted: this item and all following ones are dropped, which
            # is a truncation. The comma is not charged because the item is not emitted.
            truncated = True
            break
        used_size += separator

        part = self._truncate_json_primitives(item, target_size - used_size)
        kept.append(part.value)
        used_size += part.value_size
        # Accumulate: one truncated item marks the whole array as truncated.
        truncated = truncated or part.truncated

    return _PartResult(kept, used_size, truncated)
@classmethod
def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
    """Return True when ``m`` holds a list under the ``_QAKeys.QA_CHUNKS`` key.

    A missing key, ``None``, or any non-list value yields False. Only the
    container type is checked, not the items inside the list.
    """
    return isinstance(m.get(_QAKeys.QA_CHUNKS), list)
# Update remaining size
item_size = self.calculate_json_size(truncated_item)
remaining_size -= item_size
@classmethod
def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
    """Return True when ``m`` has the top-level shape of a parent/child payload.

    That means a ``str`` under ``_PCKeys.PARENT_MODE`` and a ``list`` under
    ``_PCKeys.PARENT_CHILD_CHUNKS``. A missing key or ``None`` yields False.
    Only these two container types are checked, not the list items.
    """
    return isinstance(m.get(_PCKeys.PARENT_MODE), str) and isinstance(
        m.get(_PCKeys.PARENT_CHILD_CHUNKS), list
    )
return truncated_items, True
def _truncate_object(self, value: Mapping[str, Any], target_size: int) -> tuple[Mapping[str, Any], bool]:
def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
"""
Truncate object with key preservation priority.
@ -258,91 +333,87 @@ class VariableTruncator:
1. Keep all keys, truncate values to fit within budget
2. If still too large, drop keys starting from the end
"""
if not value:
return value, False
if not mapping:
return _PartResult(mapping, self.calculate_json_size(mapping), False)
truncated_obj = {}
was_truncated = False
remaining_size = target_size - 2 # Account for {}
truncated = False
used_size = self.calculate_json_size({})
# Sort keys to ensure deterministic behavior
sorted_keys = sorted(value.keys())
sorted_keys = sorted(mapping.keys())
for i, key in enumerate(sorted_keys):
val = value[key]
if i > 0:
remaining_size -= 1 # Account for comma
if remaining_size <= 0:
if used_size > target_size:
# No more room for additional key-value pairs
was_truncated = True
truncated = True
break
pair_size = 0
if i > 0:
pair_size += 1 # Account for comma
# Calculate budget for this key-value pair
key_size = self.calculate_json_size(str(key)) + 1 # +1 for ":"
# do not try to truncate keys, as we want to keep the structure of
# object.
key_size = self.calculate_json_size(key) + 1 # +1 for ":"
pair_size += key_size
remaining_pairs = len(sorted_keys) - i
value_budget = max(0, (remaining_size - key_size) // remaining_pairs)
value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)
if value_budget <= 0:
was_truncated = True
truncated = True
break
# Truncate the value to fit within budget
truncated_val, val_truncated = self._truncate_value_to_budget(val, value_budget)
truncated_obj[key] = truncated_val
if val_truncated:
was_truncated = True
# Update remaining size
pair_size = key_size + self.calculate_json_size(truncated_val)
remaining_size -= pair_size
return truncated_obj, was_truncated or len(truncated_obj) < len(value)
def _truncate_item_to_budget(self, item: Any, budget: int) -> tuple[Any, bool]:
"""Truncate an array item to fit within a size budget."""
if isinstance(item, str):
# For strings, truncate to fit within budget (accounting for quotes)
max_chars = max(0, budget - 5) # -5 for quotes and potential "..."
max_chars = min(max_chars, ARRAY_CHAR_LIMIT)
if len(item) <= max_chars:
return item, False
return item[:max_chars] + "...", True
elif isinstance(item, dict):
# For objects, recursively truncate
return self._truncate_object(item, budget)
elif isinstance(item, list):
# For nested arrays, recursively truncate
return self._truncate_array(item, budget)
else:
# For other types, check if they fit
item_size = self.calculate_json_size(item)
if item_size <= budget:
return item, False
value = mapping[key]
if isinstance(value, Segment):
value_result = self._truncate_segment(value, value_budget)
else:
# Convert to string and truncate
str_item = str(item)
return self._truncate_item_to_budget(str_item, budget)
value_result = self._truncate_json_primitives(mapping[key], value_budget)
def _truncate_value_to_budget(self, val: Any, budget: int) -> tuple[Any, bool]:
truncated_obj[key] = value_result.value
pair_size += value_result.value_size
used_size += pair_size
if value_result.truncated:
truncated = True
return _PartResult(truncated_obj, used_size, truncated)
@overload
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...

@overload
def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...

@overload
def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

@overload
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...

@overload
def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...

@overload
def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...

@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...

def _truncate_json_primitives(
    self, val: str | list | dict | bool | int | float | None, target_size: int
) -> _PartResult[Any]:
    """Truncate a plain JSON value to fit within ``target_size``.

    Strings, lists and dicts are delegated to ``_truncate_string``,
    ``_truncate_array`` and ``_truncate_object`` respectively. ``None``, booleans
    and numbers are returned unchanged with their estimated size.

    Args:
        val: The value to truncate.
        target_size: Size budget handed to the type-specific truncation method.

    Returns:
        A ``_PartResult`` with the resulting value, its size, and the truncated flag.

    Raises:
        AssertionError: If ``val`` is none of the supported types.
    """
    if isinstance(val, str):
        return self._truncate_string(val, target_size)
    if isinstance(val, list):
        return self._truncate_array(val, target_size)
    if isinstance(val, dict):
        return self._truncate_object(val, target_size)
    if val is None or isinstance(val, (bool, int, float)):
        # Scalars are never shortened; only their size is reported.
        return _PartResult(val, self.calculate_json_size(val), False)
    raise AssertionError("this statement should be unreachable.")