Merge main HEAD (segment 5) into sandboxed-agent-rebase

Resolve 83 conflicts: 10 backend, 62 frontend, 11 config/lock files.
Preserve sandbox/agent/collaboration features while adopting main's
UI refactorings (Dialog/AlertDialog/Popover), model provider updates,
and enterprise features.

Made-with: Cursor
This commit is contained in:
Novice
2026-03-23 14:20:06 +08:00
1671 changed files with 124822 additions and 22302 deletions

View File

@ -234,6 +234,7 @@ class TestAdvancedChatAppGeneratorInternals:
captured: dict[str, object] = {}
prefill_calls: list[object] = []
var_loader = SimpleNamespace(loader="draft")
workflow = SimpleNamespace(id="workflow-id")
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
@ -260,8 +261,8 @@ class TestAdvancedChatAppGeneratorInternals:
def __init__(self, session):
_ = session
def prefill_conversation_variable_default_values(self, workflow):
prefill_calls.append(workflow)
def prefill_conversation_variable_default_values(self, workflow, user_id):
prefill_calls.append((workflow, user_id))
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
@ -273,7 +274,7 @@ class TestAdvancedChatAppGeneratorInternals:
result = generator.single_iteration_generate(
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
workflow=SimpleNamespace(id="workflow-id"),
workflow=workflow,
node_id="node-1",
user=SimpleNamespace(id="user-id"),
args={"inputs": {"foo": "bar"}},
@ -281,7 +282,7 @@ class TestAdvancedChatAppGeneratorInternals:
)
assert result == {"ok": True}
assert prefill_calls
assert prefill_calls == [(workflow, "user-id")]
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1"
@ -291,6 +292,7 @@ class TestAdvancedChatAppGeneratorInternals:
captured: dict[str, object] = {}
prefill_calls: list[object] = []
var_loader = SimpleNamespace(loader="draft")
workflow = SimpleNamespace(id="workflow-id")
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
@ -317,8 +319,8 @@ class TestAdvancedChatAppGeneratorInternals:
def __init__(self, session):
_ = session
def prefill_conversation_variable_default_values(self, workflow):
prefill_calls.append(workflow)
def prefill_conversation_variable_default_values(self, workflow, user_id):
prefill_calls.append((workflow, user_id))
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
@ -330,7 +332,7 @@ class TestAdvancedChatAppGeneratorInternals:
result = generator.single_loop_generate(
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
workflow=SimpleNamespace(id="workflow-id"),
workflow=workflow,
node_id="node-2",
user=SimpleNamespace(id="user-id"),
args=SimpleNamespace(inputs={"foo": "bar"}),
@ -338,7 +340,7 @@ class TestAdvancedChatAppGeneratorInternals:
)
assert result == {"ok": True}
assert prefill_calls
assert prefill_calls == [(workflow, "user-id")]
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_loop_run.node_id == "node-2"

View File

@ -44,11 +44,22 @@ class TestAgentChatAppGenerateResponseConverterBlocking:
metadata={
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
}
],
"annotation_reply": {"id": "a"},
@ -107,11 +118,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
metadata={
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "summary",
"extra": "ignored",
}
@ -151,11 +173,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
assert "usage" not in metadata
assert metadata["retriever_resources"] == [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s1",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "content",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
]

View File

@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import Mock
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_prefers_event_finished_at(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Finished timestamps should come from the event, not delayed queue processing time."""
converter = self.create_workflow_response_converter()
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr(
"core.app.apps.common.workflow_response_converter.naive_utc_now",
lambda: delayed_processing_time,
)
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=BuiltinNodeTypes.CODE,
node_execution_id="node-exec-1",
start_at=start_at,
finished_at=finished_at,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
assert response is not None
assert response.data.elapsed_time == 2.0
assert response.data.finished_at == int(finished_at.timestamp())
def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()

View File

@ -38,11 +38,22 @@ class TestCompletionAppGenerateResponseConverter:
metadata = {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset 1",
"document_id": "document-1",
"segment_id": "s",
"position": 1,
"data_source_type": "file",
"document_name": "doc",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "abc1234",
"content": "c",
"page": 5,
"title": "Citation Title",
"files": [{"id": "file-1"}],
"summary": "sum",
"extra": "x",
}
@ -66,7 +77,12 @@ class TestCompletionAppGenerateResponseConverter:
assert "annotation_reply" not in result["metadata"]
assert "usage" not in result["metadata"]
assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1"
assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1"
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file"
assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3
assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234"
assert "extra" not in result["metadata"]["retriever_resources"][0]
def test_convert_blocking_simple_response_metadata_not_dict(self):

View File

@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation, Message
@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id="user-id",
from_account_id=None,
)

View File

@ -0,0 +1,60 @@
from datetime import UTC, datetime
from unittest.mock import Mock
import pytest
from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
_NodeRuntimeSnapshot,
)
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
from dify_graph.node_events import NodeRunResult
def _build_layer() -> WorkflowPersistenceLayer:
    """Build a WorkflowPersistenceLayer wired to mocked repositories for unit tests."""
    entity = Mock()
    entity.inputs = {}
    info = PersistenceWorkflowInfo(
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        version="1",
        graph_data={},
    )
    return WorkflowPersistenceLayer(
        application_generate_entity=entity,
        workflow_info=info,
        workflow_execution_repository=Mock(),
        workflow_node_execution_repository=Mock(),
    )
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
    """The persisted finished_at/elapsed_time must come from the event, not a delayed wall clock."""
    layer = _build_layer()

    node_execution = Mock()
    node_execution.id = "node-exec-1"
    node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    node_execution.update_from_mapping = Mock()
    layer._node_snapshots["node-exec-1"] = _NodeRuntimeSnapshot(
        node_id="node-id",
        title="LLM",
        predecessor_node_id=None,
        iteration_id="iter-1",
        loop_id=None,
        created_at=node_execution.created_at,
    )

    event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
    # Simulate queue processing running 8 seconds late; naive_utc_now() must not win.
    late_now = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
    monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: late_now)

    layer._update_node_execution(
        node_execution,
        NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
        WorkflowNodeExecutionStatus.SUCCEEDED,
        finished_at=event_finished_at,
    )

    assert node_execution.finished_at == event_finished_at
    assert node_execution.elapsed_time == 2.0

View File

@ -166,6 +166,7 @@ class TestDatasourceFileManager:
# Setup
mock_guess_ext.return_value = None # Cannot guess
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(

View File

@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
from models.enums import CredentialSourceType
from models.provider import ProviderType
from models.provider_ids import ModelProviderID
@ -409,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value
assert existing_record.preferred_provider_type == ProviderType.SYSTEM
session.commit.assert_called_once()
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-base",
name="LB Base",
credentials={},
credential_source_type="provider",
credential_source_type=CredentialSourceType.PROVIDER,
)
],
),
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-custom",
name="LB Custom",
credentials={},
credential_source_type="custom_model",
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
)
],
),
@ -734,7 +735,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
configuration = _build_provider_configuration()
session = Mock()
provider_record = SimpleNamespace(is_valid=False)
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
with _patched_session(session):
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
@ -743,6 +744,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
configuration.create_provider_credential({"api_key": "raw"}, "Main")
assert provider_record.is_valid is True
assert provider_record.credential_id == "existing-cred"
session.commit.assert_called_once()
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
configuration = _build_provider_configuration()
session = Mock()
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
with _patched_session(session):
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
configuration.create_provider_credential({"api_key": "raw"}, "Main")
assert provider_record.is_valid is True
assert provider_record.credential_id is not None
session.commit.assert_called_once()
@ -807,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1",
credential_record=credential_record,
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)
@ -825,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1",
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)

View File

@ -801,6 +801,27 @@ class TestAuthOrchestration:
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
def test_build_protected_resource_metadata_discovery_urls_with_relative_hint(self):
urls = build_protected_resource_metadata_discovery_urls(
"/.well-known/oauth-protected-resource/tenant/mcp",
"https://api.example.com/tenant/mcp",
)
assert urls == [
"https://api.example.com/.well-known/oauth-protected-resource/tenant/mcp",
"https://api.example.com/.well-known/oauth-protected-resource",
]
def test_build_protected_resource_metadata_discovery_urls_ignores_scheme_less_hint(self):
urls = build_protected_resource_metadata_discovery_urls(
"/openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp",
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/tenant/mcp",
)
assert urls == [
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp",
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource",
]
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
# Case 1: with auth_server_url
urls = build_oauth_authorization_server_metadata_discovery_urls(

View File

@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension
class TestApiModeration:
    """Unit tests for ApiModeration: param models, config validation, input/output
    moderation paths, and the API-based-extension requestor plumbing."""

    @pytest.fixture
    def api_config(self):
        # Minimal valid config: both directions enabled, extension id present.
        return {
            "inputs_config": {
                "enabled": True,
            },
            "outputs_config": {
                "enabled": True,
            },
            "api_based_extension_id": "test-extension-id",
        }

    @pytest.fixture
    def api_moderation(self, api_config):
        return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)

    def test_moderation_input_params(self):
        """ModerationInputParams stores explicit values and falls back to empty defaults."""
        params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
        assert params.app_id == "app-1"
        assert params.inputs == {"key": "val"}
        assert params.query == "test query"

        # Test defaults
        params_default = ModerationInputParams()
        assert params_default.app_id == ""
        assert params_default.inputs == {}
        assert params_default.query == ""

    def test_moderation_output_params(self):
        """ModerationOutputParams accepts explicit values; constructing it with no
        arguments raises a ValidationError."""
        params = ModerationOutputParams(app_id="app-1", text="test text")
        assert params.app_id == "app-1"
        assert params.text == "test text"

        with pytest.raises(ValidationError):
            ModerationOutputParams()

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_success(self, mock_get_extension, api_config):
        """validate_config passes when the referenced extension can be resolved."""
        mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
        ApiModeration.validate_config("test-tenant-id", api_config)
        mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")

    def test_validate_config_missing_extension_id(self):
        """validate_config rejects configs missing api_based_extension_id."""
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": True},
        }
        with pytest.raises(ValueError, match="api_based_extension_id is required"):
            ApiModeration.validate_config("test-tenant-id", config)

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
        """validate_config raises when the extension lookup returns None."""
        mock_get_extension.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            ApiModeration.validate_config("test-tenant-id", api_config)

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
        """With inputs moderation enabled, the API verdict is mapped onto ModerationInputsResult."""
        mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
        result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")

        assert isinstance(result, ModerationInputsResult)
        assert result.flagged is True
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == "Blocked by API"
        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_INPUT,
            {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
        )

    def test_moderation_for_inputs_disabled(self):
        """With inputs moderation disabled, the result is an un-flagged default."""
        config = {
            "inputs_config": {"enabled": False},
            "outputs_config": {"enabled": True},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_inputs(inputs={}, query="")
        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == ""

    def test_moderation_for_inputs_no_config(self):
        """Input moderation without any config raises."""
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_inputs({}, "")

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
        """Output moderation forwards the text to the extension and maps the response."""
        mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
        result = api_moderation.moderation_for_outputs(text="hello world")

        assert isinstance(result, ModerationOutputsResult)
        assert result.flagged is False
        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
        )

    def test_moderation_for_outputs_disabled(self):
        """With outputs moderation disabled, the result is an un-flagged default."""
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": False},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_outputs(text="test")
        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT

    def test_moderation_for_outputs_no_config(self):
        """Output moderation without any config raises."""
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_outputs("test")

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    @patch("core.moderation.api.api.decrypt_token")
    @patch("core.moderation.api.api.APIBasedExtensionRequestor")
    def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
        """_get_config_by_requestor resolves the extension, decrypts its key, and issues the request.

        NOTE: @patch decorators apply bottom-up, so the mock arguments arrive in
        reverse order of the decorator list.
        """
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_ext.api_endpoint = "http://api.test"
        mock_ext.api_key = "encrypted-key"
        mock_get_ext.return_value = mock_ext
        mock_decrypt.return_value = "decrypted-key"

        mock_requestor = MagicMock()
        mock_requestor.request.return_value = {"flagged": True}
        mock_requestor_cls.return_value = mock_requestor

        params = {"some": "params"}
        result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)

        assert result == {"flagged": True}
        mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
        mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
        mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
        mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)

    def test_get_config_by_requestor_no_config(self):
        """Requestor access without any config raises."""
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
        """A missing extension surfaces as a ValueError from the requestor path."""
        mock_get_ext.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.db.session.scalar")
    def test_get_api_based_extension(self, mock_scalar):
        """_get_api_based_extension returns whatever the DB session scalar yields."""
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_scalar.return_value = mock_ext

        result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")

        assert result == mock_ext
        mock_scalar.assert_called_once()
        # Verify the call has the correct filters
        args, kwargs = mock_scalar.call_args
        stmt = args[0]
        # We can't easily inspect the statement without complex sqlalchemy tricks,
        # but calling it is usually enough for unit tests if we mock the result.

View File

@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager
class TestInputModeration:
    """Unit tests for InputModeration.check covering the pass-through, un-flagged,
    traced, and flagged (direct-output / overridden / other) paths."""

    @pytest.fixture
    def app_config(self):
        # App config with moderation disabled by default; tests enable it as needed.
        config = MagicMock(spec=AppConfig)
        config.sensitive_word_avoidance = None
        return config

    @pytest.fixture
    def input_moderation(self):
        return InputModeration()

    def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
        """Without a sensitive-word-avoidance config, check() is a pass-through."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
        """An un-flagged moderation result leaves inputs and query untouched."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {"keywords": ["bad"]}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query
        mock_factory_cls.assert_called_once_with(
            name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
        )
        mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)

    @patch("core.moderation.input_moderation.ModerationFactory")
    @patch("core.moderation.input_moderation.TraceTask")
    def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
        """When a trace manager is supplied, a MODERATION_TRACE task is queued with the result.

        NOTE: @patch decorators apply bottom-up, so mock_trace_task/mock_factory_cls
        arrive in reverse order of the decorator list.
        """
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        trace_manager = MagicMock(spec=TraceQueueManager)

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result

        input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
            trace_manager=trace_manager,
        )

        trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
        mock_trace_task.assert_called_once()
        call_kwargs = mock_trace_task.call_args.kwargs
        call_args = mock_trace_task.call_args.args
        assert call_args[0] == TraceTaskName.MODERATION_TRACE
        assert call_kwargs["message_id"] == message_id
        assert call_kwargs["moderation_result"] == mock_result
        assert call_kwargs["inputs"] == inputs
        assert "timer" in call_kwargs

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
        """A flagged DIRECT_OUTPUT result raises ModerationError carrying the preset response."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
        )
        mock_factory.moderation_for_inputs.return_value = mock_result

        with pytest.raises(ModerationError) as excinfo:
            input_moderation.check(
                app_id=app_id,
                tenant_id=tenant_id,
                app_config=app_config,
                inputs=inputs,
                query=query,
                message_id=message_id,
            )

        assert str(excinfo.value) == "Blocked content"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
        """A flagged OVERRIDDEN result replaces both inputs and query with the moderated values."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True,
            action=ModerationAction.OVERRIDDEN,
            inputs={"input_key": "overridden_value"},
            query="overridden query",
        )
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is True
        assert final_inputs == {"input_key": "overridden_value"}
        assert final_query == "overridden query"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
        """A flagged result with an unrecognized action flags the call but keeps the originals."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = MagicMock()
        mock_result.flagged = True
        mock_result.action = "NONE"  # Some other action
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
        )

        assert flagged is True
        assert final_inputs == inputs
        assert final_query == query

View File

@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration
class TestOutputModeration:
@pytest.fixture
def mock_queue_manager(self):
    # Queue manager stub; tests only inspect its publish() calls.
    return MagicMock(spec=AppQueueManager)
@pytest.fixture
def moderation_rule(self):
    # Keyword-based moderation rule consumed by the output_moderation fixture.
    return ModerationRule(type="keywords", config={"keywords": "badword"})
@pytest.fixture
def output_moderation(self, mock_queue_manager, moderation_rule):
    # System under test, wired with the mocked queue manager and keyword rule.
    return OutputModeration(
        tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
    )
def test_should_direct_output(self, output_moderation):
    """should_direct_output flips to True once final_output has been recorded."""
    assert output_moderation.should_direct_output() is False
    output_moderation.final_output = "blocked"
    assert output_moderation.should_direct_output() is True
def test_get_final_output(self, output_moderation):
    """get_final_output returns the stored final output, or an empty string when unset."""
    assert output_moderation.get_final_output() == ""
    output_moderation.final_output = "blocked"
    assert output_moderation.get_final_output() == "blocked"
def test_append_new_token(self, output_moderation):
    """append_new_token accumulates the buffer and starts the worker thread only once."""
    with patch.object(OutputModeration, "start_thread") as mock_start:
        output_moderation.append_new_token("hello")
        assert output_moderation.buffer == "hello"
        mock_start.assert_called_once()

        # With a thread already attached, no second start is triggered.
        output_moderation.thread = MagicMock()
        output_moderation.append_new_token(" world")
        assert output_moderation.buffer == "hello world"
        assert mock_start.call_count == 1
def test_moderation_completion_no_flag(self, output_moderation):
    """Un-flagged completions pass through unchanged and mark the final chunk."""
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        output, flagged = output_moderation.moderation_completion("safe content")

        assert output == "safe content"
        assert flagged is False
        assert output_moderation.is_final_chunk is True
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "preset"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert isinstance(args[0], QueueMessageReplaceEvent)
assert args[0].text == "preset"
assert args[1] == PublishFrom.TASK_PIPELINE
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "masked content"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked content"
def test_start_thread(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
mock_current_app._get_current_object.return_value = mock_app
with patch("threading.Thread") as mock_thread_class:
mock_thread_instance = MagicMock()
mock_thread_class.return_value = mock_thread_instance
thread = output_moderation.start_thread()
assert thread == mock_thread_instance
mock_thread_class.assert_called_once()
mock_thread_instance.start.assert_called_once()
def test_stop_thread(self, output_moderation):
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
output_moderation.thread = mock_thread
output_moderation.stop_thread()
assert output_moderation.thread_running is False
output_moderation.thread_running = True
mock_thread.is_alive.return_value = False
output_moderation.stop_thread()
assert output_moderation.thread_running is True
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_success(self, mock_factory_class, output_moderation):
mock_factory = mock_factory_class.return_value
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_outputs.return_value = mock_result
result = output_moderation.moderation("tenant", "app", "buffer")
assert result == mock_result
mock_factory_class.assert_called_once_with(
name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
)
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_exception(self, mock_factory_class, output_moderation):
mock_factory_class.side_effect = Exception("error")
result = output_moderation.moderation("tenant", "app", "buffer")
assert result is None
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
# Test exit on thread_running=False
output_moderation.thread_running = False
output_moderation.worker(mock_app, 10)
# Should exit immediately
def test_worker_no_flag(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output_moderation.buffer = "safe"
output_moderation.is_final_chunk = True
# To avoid infinite loop, we'll set thread_running to False after one iteration
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
return mock_moderation.return_value
mock_moderation.side_effect = side_effect
output_moderation.worker(mock_app, 10)
assert mock_moderation.called
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
assert output_moderation.final_output == "preset"
mock_queue_manager.publish.assert_called_once()
# It breaks on DIRECT_OUTPUT
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Use side_effect to change thread_running on second call
def side_effect(*args, **kwargs):
if mock_moderation.call_count > 1:
output_moderation.thread_running = False
return None
return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")
mock_moderation.side_effect = side_effect
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked"
def test_worker_chunk_too_small(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("time.sleep") as mock_sleep:
# chunk_length < buffer_size and not is_final_chunk
output_moderation.buffer = "123" # length 3
output_moderation.is_final_chunk = False
def sleep_side_effect(seconds):
output_moderation.thread_running = False
mock_sleep.side_effect = sleep_side_effect
output_moderation.worker(mock_app, 10) # buffer_size 10
mock_sleep.assert_called_once_with(1)
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Return None (exception or no rule)
mock_moderation.return_value = None
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
mock_moderation.side_effect = side_effect
output_moderation.buffer = "something"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_not_called()

View File

@ -0,0 +1,160 @@
from unittest.mock import patch
import httpx
import pytest
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
TidbOnQdrantConfig,
TidbOnQdrantVector,
)
class TestTidbOnQdrantVectorDeleteByIds:
    """Unit tests for TidbOnQdrantVector.delete_by_ids method."""
    @pytest.fixture
    def vector_instance(self):
        """Create a TidbOnQdrantVector instance for testing."""
        config = TidbOnQdrantConfig(
            endpoint="http://localhost:6333",
            api_key="test_api_key",
        )
        # Patching QdrantClient means vector._client becomes a MagicMock: no network I/O.
        with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
            vector = TidbOnQdrantVector(
                collection_name="test_collection",
                group_id="test_group",
                config=config,
            )
        return vector
    def test_delete_by_ids_with_multiple_ids(self, vector_instance):
        """Test batch deletion with multiple document IDs."""
        ids = ["doc1", "doc2", "doc3"]
        vector_instance.delete_by_ids(ids)
        # Verify that delete was called once with MatchAny filter
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        # Check collection name (call_args[1] is the kwargs dict of the mocked call)
        assert call_args[1]["collection_name"] == "test_collection"
        # Verify filter uses MatchAny with all IDs
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        assert len(filter_obj.must) == 1
        field_condition = filter_obj.must[0]
        assert field_condition.key == "metadata.doc_id"
        assert isinstance(field_condition.match, rest.MatchAny)
        # Compared as a set: MatchAny order is irrelevant to deletion semantics.
        assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}
    def test_delete_by_ids_with_single_id(self, vector_instance):
        """Test deletion with a single document ID."""
        ids = ["doc1"]
        vector_instance.delete_by_ids(ids)
        # Verify that delete was called once
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        # Verify filter uses MatchAny with single ID
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ["doc1"]
    def test_delete_by_ids_with_empty_list(self, vector_instance):
        """Test deletion with empty ID list returns early without API call."""
        vector_instance.delete_by_ids([])
        # Verify that delete was NOT called
        vector_instance._client.delete.assert_not_called()
    def test_delete_by_ids_with_404_error(self, vector_instance):
        """Test that 404 errors (collection not found) are handled gracefully."""
        ids = ["doc1", "doc2"]
        # Mock a 404 error
        error = UnexpectedResponse(
            status_code=404,
            reason_phrase="Not Found",
            content=b"Collection not found",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error
        # Should not raise an exception
        vector_instance.delete_by_ids(ids)
        # Verify delete was called
        vector_instance._client.delete.assert_called_once()
    def test_delete_by_ids_with_unexpected_error(self, vector_instance):
        """Test that non-404 errors are re-raised."""
        ids = ["doc1", "doc2"]
        # Mock a 500 error
        error = UnexpectedResponse(
            status_code=500,
            reason_phrase="Internal Server Error",
            content=b"Server error",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error
        # Should re-raise the exception
        with pytest.raises(UnexpectedResponse) as exc_info:
            vector_instance.delete_by_ids(ids)
        assert exc_info.value.status_code == 500
    def test_delete_by_ids_with_large_batch(self, vector_instance):
        """Test deletion with a large batch of IDs."""
        # Create 1000 IDs
        ids = [f"doc_{i}" for i in range(1000)]
        vector_instance.delete_by_ids(ids)
        # Verify single delete call with all IDs (no client-side chunking)
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        # Verify all 1000 IDs are in the batch
        assert len(field_condition.match.any) == 1000
        assert "doc_0" in field_condition.match.any
        assert "doc_999" in field_condition.match.any
    def test_delete_by_ids_filter_structure(self, vector_instance):
        """Test that the filter structure is correctly constructed."""
        ids = ["doc1", "doc2"]
        vector_instance.delete_by_ids(ids)
        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        # Verify Filter structure
        assert isinstance(filter_obj, rest.Filter)
        assert filter_obj.must is not None
        assert len(filter_obj.must) == 1
        # Verify FieldCondition structure
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition, rest.FieldCondition)
        assert field_condition.key == "metadata.doc_id"
        # Verify MatchAny structure (here order must be preserved exactly)
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ids

View File

@ -0,0 +1,33 @@
from unittest.mock import MagicMock, patch
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
def test_init_client_with_valid_config():
    """Client construction succeeds and forwards every connection detail from the config."""
    cfg = WeaviateConfig(
        endpoint="http://localhost:8080",
        api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
    )
    with patch("weaviate.connect_to_custom") as mock_connect:
        fake_client = MagicMock()
        fake_client.is_ready.return_value = True
        mock_connect.return_value = fake_client
        vector = WeaviateVector(
            collection_name="test_collection",
            config=cfg,
            attributes=["doc_id"],
        )
        # The constructor must connect exactly once and keep the returned client.
        mock_connect.assert_called_once()
        assert vector._client == fake_client
        # The endpoint is split into http/grpc host-port pairs, insecure for plain http.
        connect_kwargs = mock_connect.call_args[1]
        assert connect_kwargs["http_host"] == "localhost"
        assert connect_kwargs["http_port"] == 8080
        assert connect_kwargs["http_secure"] is False
        assert connect_kwargs["grpc_host"] == "localhost"
        assert connect_kwargs["grpc_port"] == 50051
        assert connect_kwargs["grpc_secure"] is False
        assert connect_kwargs["auth_credentials"] is not None

View File

@ -0,0 +1,335 @@
"""Unit tests for Weaviate vector database implementation.
Focuses on verifying that doc_type is properly handled in:
- Collection schema creation (_create_collection)
- Property migration (_ensure_properties)
- Vector search result metadata (search_by_vector)
- Full-text search result metadata (search_by_full_text)
"""
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from core.rag.models.document import Document
class TestWeaviateVector(unittest.TestCase):
    """Tests for WeaviateVector class with focus on doc_type metadata handling."""
    def setUp(self):
        # Reset the module-level client handle so state does not leak between tests
        # (presumably a cached shared client — inferred from this reset pattern; confirm in module).
        weaviate_vector_module._weaviate_client = None
        self.config = WeaviateConfig(
            endpoint="http://localhost:8080",
            api_key="test-key",
            batch_size=100,
        )
        self.collection_name = "Test_Collection_Node"
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
    def tearDown(self):
        weaviate_vector_module._weaviate_client = None
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def _create_weaviate_vector(self, mock_weaviate_module):
        """Helper to create a WeaviateVector instance with mocked client."""
        # @patch wraps each invocation, so callers call this helper with no arguments.
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        return wv, mock_client
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_init(self, mock_weaviate_module):
        """Test WeaviateVector initialization stores attributes including doc_type."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        assert wv._collection_name == self.collection_name
        assert "doc_type" in wv._attributes
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_create_collection_includes_doc_type_property(self, mock_weaviate_module, mock_dify_config, mock_redis):
        """Test that _create_collection defines doc_type in the schema properties."""
        # Decorators apply bottom-up, so arguments arrive as (weaviate, dify_config, redis).
        # Mock Redis
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock dify_config
        mock_dify_config.WEAVIATE_TOKENIZATION = None
        # Mock client
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = False
        # Mock _ensure_properties to avoid side effects
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_cfg = MagicMock()
        mock_cfg.properties = []
        mock_col.config.get.return_value = mock_cfg
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._create_collection()
        # Verify collections.create was called
        mock_client.collections.create.assert_called_once()
        # Extract properties from the create call
        call_kwargs = mock_client.collections.create.call_args
        properties = call_kwargs.kwargs.get("properties")
        # Verify doc_type is among the defined properties
        property_names = [p.name for p in properties]
        assert "doc_type" in property_names, (
            f"doc_type should be in collection schema properties, got: {property_names}"
        )
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module):
        """Test that _ensure_properties adds doc_type when it's missing from existing schema."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        # Collection exists but doc_type property is missing
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        # Simulate existing properties WITHOUT doc_type
        existing_props = [
            SimpleNamespace(name="text"),
            SimpleNamespace(name="document_id"),
            SimpleNamespace(name="doc_id"),
            SimpleNamespace(name="chunk_index"),
        ]
        mock_cfg = MagicMock()
        mock_cfg.properties = existing_props
        mock_col.config.get.return_value = mock_cfg
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._ensure_properties()
        # Verify add_property was called and includes doc_type
        add_calls = mock_col.config.add_property.call_args_list
        added_names = [call.args[0].name for call in add_calls]
        assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
        """Test that _ensure_properties does not add doc_type when it already exists."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        # Simulate existing properties WITH doc_type already present
        existing_props = [
            SimpleNamespace(name="text"),
            SimpleNamespace(name="document_id"),
            SimpleNamespace(name="doc_id"),
            SimpleNamespace(name="doc_type"),
            SimpleNamespace(name="chunk_index"),
        ]
        mock_cfg = MagicMock()
        mock_cfg.properties = existing_props
        mock_col.config.get.return_value = mock_cfg
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._ensure_properties()
        # No properties should be added
        mock_col.config.add_property.assert_not_called()
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
        """Test that search_by_vector returns doc_type in document metadata.
        This is the core bug fix verification: when doc_type is in _attributes,
        it should appear in return_properties and thus be included in results.
        """
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        # Simulate search result with doc_type in properties
        mock_obj = MagicMock()
        mock_obj.properties = {
            "text": "image content description",
            "doc_id": "upload_file_id_123",
            "dataset_id": "dataset_1",
            "document_id": "doc_1",
            "doc_hash": "hash_abc",
            "doc_type": "image",
        }
        mock_obj.metadata.distance = 0.1
        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.near_vector.return_value = mock_result
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        docs = wv.search_by_vector(query_vector=[0.1] * 128, top_k=1)
        # Verify doc_type is in return_properties
        call_kwargs = mock_col.query.near_vector.call_args
        return_props = call_kwargs.kwargs.get("return_properties")
        assert "doc_type" in return_props, f"doc_type should be in return_properties, got: {return_props}"
        # Verify doc_type is in result metadata
        assert len(docs) == 1
        assert docs[0].metadata.get("doc_type") == "image"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
        """Test that search_by_full_text also returns doc_type in document metadata."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        # Simulate BM25 search result with doc_type
        mock_obj = MagicMock()
        mock_obj.properties = {
            "text": "image content description",
            "doc_id": "upload_file_id_456",
            "doc_type": "image",
        }
        mock_obj.vector = {"default": [0.1] * 128}
        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.bm25.return_value = mock_result
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        docs = wv.search_by_full_text(query="image", top_k=1)
        # Verify doc_type is in return_properties
        call_kwargs = mock_col.query.bm25.call_args
        return_props = call_kwargs.kwargs.get("return_properties")
        assert "doc_type" in return_props, (
            f"doc_type should be in return_properties for BM25 search, got: {return_props}"
        )
        # Verify doc_type is in result metadata
        assert len(docs) == 1
        assert docs[0].metadata.get("doc_type") == "image"
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
        """Test that add_texts includes doc_type from document metadata in stored properties."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        # Create a document with doc_type metadata (as produced by multimodal indexing)
        doc = Document(
            page_content="an image of a cat",
            metadata={
                "doc_id": "upload_file_123",
                "doc_type": "image",
                "dataset_id": "ds_1",
                "document_id": "doc_1",
                "doc_hash": "hash_xyz",
            },
        )
        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        # Mock batch context manager
        mock_batch = MagicMock()
        mock_batch.__enter__ = MagicMock(return_value=mock_batch)
        mock_batch.__exit__ = MagicMock(return_value=False)
        mock_col.batch.dynamic.return_value = mock_batch
        wv.add_texts(documents=[doc], embeddings=[[0.1] * 128])
        # Verify batch.add_object was called with doc_type in properties
        mock_batch.add_object.assert_called_once()
        call_kwargs = mock_batch.add_object.call_args
        stored_props = call_kwargs.kwargs.get("properties")
        assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}"
class TestVectorDefaultAttributes(unittest.TestCase):
    """Tests for the Vector factory's default attribute list."""
    @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
    @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
    def test_default_attributes_include_doc_type(self, mock_init_vector, mock_get_embeddings):
        """doc_type must be part of the default attributes Vector hands to its backend."""
        from core.rag.datasource.vdb.vector_factory import Vector
        # Stub out embedding-model resolution and backend construction entirely.
        for stubbed in (mock_init_vector, mock_get_embeddings):
            stubbed.return_value = MagicMock()
        fake_dataset = MagicMock()
        fake_dataset.index_struct_dict = None
        vec = Vector(dataset=fake_dataset)
        assert "doc_type" in vec._attributes, f"doc_type should be in default attributes, got: {vec._attributes}"
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest runner.
    unittest.main()

View File

@ -104,10 +104,11 @@ class TestFirecrawlApp:
def test_map_known_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error"))
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
assert app.map("https://example.com") == {}
with pytest.raises(Exception, match="map error"):
app.map("https://example.com")
mock_handle.assert_called_once()
def test_map_unknown_error_raises(self, mocker: MockerFixture):
@ -177,10 +178,11 @@ class TestFirecrawlApp:
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error"))
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
assert app.check_crawl_status("job-1") == {}
with pytest.raises(Exception, match="crawl error"):
app.check_crawl_status("job-1")
mock_handle.assert_called_once()
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
@ -272,9 +274,10 @@ class TestFirecrawlApp:
def test_search_known_http_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error"))
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
assert app.search("python") == {}
with pytest.raises(Exception, match="search error"):
app.search("python")
mock_handle.assert_called_once()
def test_search_unknown_http_error(self, mocker: MockerFixture):

View File

@ -236,7 +236,8 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [accepted, rejected]
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].metadata["score"] == 0.9

View File

@ -307,7 +307,8 @@ class TestParentChildIndexProcessor:
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [ok_result, low_result]
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "keep"

View File

@ -262,7 +262,8 @@ class TestQAIndexProcessor:
with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
mock_retrieve.return_value = [result_ok, result_low]
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "accepted"

View File

@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.rag.data_post_processor.data_post_processor import WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage:
extra={"dataset_name": "Ext", "title": "Ext"},
)
app = Flask(__name__)
weights = {"vector_setting": {}}
weights: WeightsDict = {
"vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""},
"keyword_setting": {"keyword_weight": 0.5},
}
def fake_multiple_thread(**kwargs):
if kwargs["query"]:

View File

@ -0,0 +1,677 @@
from __future__ import annotations
import dataclasses
import json
from collections.abc import Sequence
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_HumanInputFormEntityImpl,
_HumanInputFormRecipientEntityImpl,
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace SQLAlchemy query builders with inert, chainable stubs for every test."""
    class _ChainStub:
        # join/where/options all return the same stub so any call chain is a no-op.
        def _chain(self, *_args: Any, **_kwargs: Any) -> Any:
            return self
        join = where = options = _chain
        del _chain
    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _ChainStub())
    monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
payload: dict[str, Any] = {
"form_content": "hi",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit"}],
"rendered_content": "<p>hi</p>",
}
if include_expiration_time:
payload["expiration_time"] = naive_utc_now()
return json.dumps(payload, default=str)
@dataclasses.dataclass
class _DummyForm:
    """Plain stand-in for a HumanInputForm row, served by _FakeSession.get()."""
    # Identity / ownership
    id: str
    workflow_run_id: str | None
    node_id: str
    tenant_id: str
    app_id: str
    # Serialized definition plus the pre-rendered content shown to recipients.
    form_definition: str
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    # Submission state: populated only once a recipient acts on the form.
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
    """Plain stand-in for a HumanInputFormRecipient row, served by _FakeSession.get()."""
    id: str
    form_id: str
    recipient_type: RecipientType
    # None models a token the DB has not generated yet (see _FakeSession.flush).
    access_token: str | None
class _FakeScalarResult:
def __init__(self, obj: Any):
self._obj = obj
def first(self) -> Any:
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self) -> list[Any]:
if self._obj is None:
return []
if isinstance(self._obj, list):
return list(self._obj)
return [self._obj]
class _FakeExecuteResult:
def __init__(self, rows: Sequence[tuple[Any, ...]]):
self._rows = list(rows)
def all(self) -> list[tuple[Any, ...]]:
return list(self._rows)
class _FakeSession:
    """In-memory stand-in for a SQLAlchemy session.

    Queued scalar results are handed out one per scalars() call; get() dispatches
    on the model class name; flush() simulates DB-side default population of ids
    and recipient access tokens.
    """
    def __init__(
        self,
        *,
        scalars_result: Any = None,
        scalars_results: list[Any] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
        execute_rows: Sequence[tuple[Any, ...]] = (),
    ):
        # scalars_results (plural) takes precedence; otherwise a single-entry queue.
        if scalars_results is not None:
            self._scalars_queue = list(scalars_results)
        else:
            self._scalars_queue = [scalars_result]
        self._forms = forms or {}
        self._recipients = recipients or {}
        self._execute_rows = list(execute_rows)
        # Objects handed to add()/add_all(), in insertion order.
        self.added: list[Any] = []
    def scalars(self, _query: Any) -> _FakeScalarResult:
        # Pop the next queued result; an exhausted queue yields empty results.
        if self._scalars_queue:
            value = self._scalars_queue.pop(0)
        else:
            value = None
        return _FakeScalarResult(value)
    def execute(self, _stmt: Any) -> _FakeExecuteResult:
        return _FakeExecuteResult(self._execute_rows)
    def get(self, model_cls: Any, obj_id: str) -> Any:
        # Dispatch by class name so this stub avoids importing real model classes here.
        name = getattr(model_cls, "__name__", "")
        if name == "HumanInputForm":
            return self._forms.get(obj_id)
        if name == "HumanInputFormRecipient":
            return self._recipients.get(obj_id)
        return None
    def add(self, obj: Any) -> None:
        self.added.append(obj)
    def add_all(self, objs: Sequence[Any]) -> None:
        self.added.extend(list(objs))
    def flush(self) -> None:
        # Simulate DB default population for attributes referenced in entity wrappers.
        for obj in self.added:
            if hasattr(obj, "id") and obj.id in (None, ""):
                # NOTE(review): len(str(self.added)) looks like it was meant to be
                # len(self.added); it still yields a non-empty id, so left as-is.
                obj.id = f"gen-{len(str(self.added))}"
            if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
                # Token value encodes the recipient channel for easy assertions.
                if obj.recipient_type == RecipientType.CONSOLE:
                    obj.access_token = "token-console"
                elif obj.recipient_type == RecipientType.BACKSTAGE:
                    obj.access_token = "token-backstage"
                else:
                    obj.access_token = "token-webapp"
    def refresh(self, _obj: Any) -> None:
        return None
    def begin(self) -> _FakeSession:
        # Transaction/context-manager protocol is a pure no-op on this stub.
        return self
    def __enter__(self) -> _FakeSession:
        return self
    def __exit__(self, exc_type, exc, tb) -> None:
        return None
class _SessionFactoryStub:
    """Replacement for the repository module's ``session_factory`` global.

    Always hands back the same pre-configured fake session.
    """

    def __init__(self, session: _FakeSession):
        self._session = session

    def create_session(self) -> _FakeSession:
        return self._session
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
    """Point the human-input repository module's session factory at the given fake session."""
    monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
def test_recipient_entity_token_raises_when_missing() -> None:
    """A recipient entity refuses to expose a token that was never issued."""
    model = SimpleNamespace(id="r1", access_token=None)
    wrapper = _HumanInputFormRecipientEntityImpl(model)  # type: ignore[arg-type]
    with pytest.raises(AssertionError, match="access_token should not be None"):
        _ = wrapper.token
def test_recipient_entity_id_and_token_success() -> None:
    """id and token pass straight through from the backing model."""
    model = SimpleNamespace(id="r1", access_token="tok")
    wrapper = _HumanInputFormRecipientEntityImpl(model)  # type: ignore[arg-type]
    assert (wrapper.id, wrapper.token) == ("r1", "tok")
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
    """web_app_token prefers a console recipient's token, then a web-app one, else None."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
    webapp = _DummyRecipient(
        id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
    )
    cases = [
        ([webapp, console], "ctok"),  # console wins even when listed last
        ([webapp], "wtok"),
        ([], None),
    ]
    for recipient_models, expected_token in cases:
        entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=recipient_models)  # type: ignore[arg-type]
        assert entity.web_app_token == expected_token
def test_form_entity_submitted_data_parsed() -> None:
    """A submitted form exposes its JSON payload parsed, plus the raw form fields."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        submitted_data='{"a": 1}',
        submitted_at=naive_utc_now(),
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.submitted is True
    assert entity.submitted_data == {"a": 1}
    assert entity.rendered_content == "<p>x</p>"
    # No action selected and the stored status is untouched by reading.
    assert entity.selected_action_id is None
    assert entity.status == HumanInputFormStatus.WAITING
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
    """from_models backfills definition.expiration_time from the DB column.

    The serialized definition omits expiration_time, so the record must pick it
    up from the model. submitted_data is parsed, yet submitted stays False —
    presumably because the model has no submitted_at/SUBMITTED status; confirm
    against HumanInputFormRecord if this changes.
    """
    expiration = naive_utc_now()
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=False),
        rendered_content="<p>x</p>",
        expiration_time=expiration,
        submitted_data='{"k": "v"}',
    )
    record = HumanInputFormRecord.from_models(form, None)  # type: ignore[arg-type]
    assert record.definition.expiration_time == expiration
    assert record.submitted_data == {"k": "v"}
    assert record.submitted is False
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
    """Email recipients are deduplicated by address and blank addresses dropped.

    ``HumanInputFormRecipient.new`` is replaced so recipient creation can be
    observed without the real model. Of three members (one blank email, two
    sharing an address) and four external emails (one blank, one duplicating a
    member address, two identical), only one member + one external recipient
    survive.
    """
    created: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload: Any):  # type: ignore[no-untyped-def]
        # Record every created recipient so the assertion can inspect types/order.
        recipient = SimpleNamespace(
            id=f"{payload.TYPE}-{len(created)}",
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
            access_token="tok",
        )
        created.append(recipient)
        return recipient

    monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    recipients = repo._create_email_recipients_from_resolved(  # type: ignore[attr-defined]
        form_id="f",
        delivery_id="d",
        members=[
            _WorkspaceMemberInfo(user_id="u1", email=""),
            _WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
            _WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
        ],
        external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
    )
    assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
    """A list of blank user ids yields an empty member list."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""])
    assert members == []
def test_query_workspace_members_by_ids_maps_rows() -> None:
    """(user_id, email) rows from the session are mapped into _WorkspaceMemberInfo."""
    raw_rows = [("u1", "a@example.com"), ("u2", "b@example.com")]
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repo._query_workspace_members_by_ids(
        session=_FakeSession(execute_rows=raw_rows),
        restrict_to_user_ids=["u1", "u2"],
    )
    assert members == [_WorkspaceMemberInfo(user_id=uid, email=email) for uid, email in raw_rows]
def test_query_all_workspace_members_maps_rows() -> None:
    """The all-members query maps raw rows into _WorkspaceMemberInfo records."""
    fake_session = _FakeSession(execute_rows=[("u1", "a@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._query_all_workspace_members(session=fake_session) == [
        _WorkspaceMemberInfo(user_id="u1", email="a@example.com")
    ]
def test_repository_init_sets_tenant_id() -> None:
    """The constructor stores the tenant id verbatim."""
    assert HumanInputFormRepositoryImpl(tenant_id="tenant")._tenant_id == "tenant"
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
    """A web-app delivery yields one delivery row plus one standalone-web-app recipient."""
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    result = repo._delivery_method_to_model(
        session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
    )
    delivery = result.delivery
    assert (delivery.id, delivery.form_id) == ("del-1", "form-1")
    assert len(result.recipients) == 1
    assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """Email delivery methods delegate recipient creation to _build_email_recipients.

    The helper is stubbed out; the test verifies its return value is passed
    through unchanged and that it receives the freshly generated delivery id.
    """
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    called: dict[str, Any] = {}

    def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
        # Capture keyword arguments for the assertions below.
        called.update(
            {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
        )
        return ["r"]

    monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
    method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
            ),
            subject="s",
            body="b",
        )
    )
    result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
    assert result.recipients == ["r"]
    assert called["delivery_id"] == "del-1"
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=True routes through the all-members lookup."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr(
        repo,
        "_query_all_workspace_members",
        lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
    )
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    config = EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")])
    built = repo._build_email_recipients(
        session=MagicMock(), form_id="f", delivery_id="d", recipients_config=config
    )
    assert built == ["ok"]
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=False restricts the member lookup to the configured user ids."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
        # Only the explicitly-listed member id may be queried.
        assert restrict_to_user_ids == ["u1"]
        return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]

    monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    config = EmailRecipients(
        whole_workspace=False,
        items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
    )
    built = repo._build_email_recipients(
        session=MagicMock(), form_id="f", delivery_id="d", recipients_config=config
    )
    assert built == ["ok"]
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_form returns None for an unknown run/node, and a wrapped entity otherwise.

    The fake session's scalars queue feeds the form lookup first and the
    recipient lookup second, so the queue order below is significant.
    """
    _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo.get_form("run", "node") is None

    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = _DummyRecipient(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
    )
    # Queue: first scalars() -> form, second scalars() -> recipient list.
    session = _FakeSession(scalars_results=[form, [recipient]])
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    entity = repo.get_form("run", "node")
    assert entity is not None
    assert entity.id == "f1"
    assert entity.recipients[0].id == "r1"
    assert entity.recipients[0].token == "tok"
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_form materializes web-app, console and backstage recipients.

    Time and id generation are pinned so the expiration and entity id are
    deterministic; the ids iterator supplies one uuid per created row in
    creation order (form first, then the deliveries).
    """
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
    ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
    session = _FakeSession()
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    form_config = HumanInputNodeData(
        title="Title",
        delivery_methods=[],
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
    )
    params = FormCreateParams(
        app_id="app",
        workflow_execution_id="run",
        node_id="node",
        form_config=form_config,
        rendered_content="<p>hello</p>",
        delivery_methods=[WebAppDeliveryMethod()],
        display_in_ui=True,
        resolved_default_values={},
        form_kind=HumanInputFormKind.RUNTIME,
        console_recipient_required=True,
        console_creator_account_id="acc-1",
        backstage_recipient_required=True,
    )
    entity = repo.create_form(params)
    assert entity.id == "form-id"
    # Expiration derives from the node config's default timeout (hours).
    assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
    # Console token should take precedence when console recipient is present.
    assert entity.web_app_token == "token-console"
    assert len(entity.recipients) == 3
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_token yields None when the recipient row or its parent form is absent."""
    for lookup_result in (None, SimpleNamespace(form=None)):
        _patch_session_factory(monkeypatch, _FakeSession(scalars_result=lookup_result))
        repo = HumanInputFormSubmissionRepository()
        assert repo.get_by_token("tok") is None
def test_submission_repository_init_no_args() -> None:
    """The submission repository is constructible without arguments."""
    assert isinstance(HumanInputFormSubmissionRepository(), HumanInputFormSubmissionRepository)
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both submission lookups build a record from a recipient row with a loaded form.

    A fresh fake session is installed before each lookup because the first call
    consumes the single queued scalars result.
    """
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    # SimpleNamespace stands in for a recipient row with its `form` relationship loaded.
    recipient = SimpleNamespace(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
        form=form,
    )
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_token("tok")
    assert record is not None
    assert record.access_token == "tok"

    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
    assert record is not None
    assert record.recipient_id == "r1"
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Lookup by form id + recipient type yields None when no recipient row exists."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    missing = repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE)
    assert missing is None
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_submitted stamps the form SUBMITTED, or raises for an unknown form.

    First half: an empty form store makes mark_submitted raise FormNotFoundError.
    Second half: a stored form + recipient has status/submitted_at updated and
    the returned record carries the parsed submission payload.
    """
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    # Freeze time so the submitted_at assertion is deterministic.
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
    missing_session = _FakeSession(forms={})
    _patch_session_factory(monkeypatch, missing_session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_submitted(
            form_id="missing",
            recipient_id=None,
            selected_action_id="a",
            form_data={},
            submission_user_id=None,
            submission_end_user_id=None,
        )

    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=fixed_now,
    )
    recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
    session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_submitted(
        form_id=form.id,
        recipient_id=recipient.id,
        selected_action_id="approve",
        form_data={"k": "v"},
        submission_user_id="u",
        submission_end_user_id="eu",
    )
    assert form.status == HumanInputFormStatus.SUBMITTED
    assert form.submitted_at == fixed_now
    assert record.submitted_data == {"k": "v"}
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_timeout rejects statuses that are not timeout statuses (e.g. SUBMITTED)."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={form.id: form}))
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED)  # type: ignore[arg-type]
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """A form already in TIMEOUT state is returned as-is instead of raising."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.TIMEOUT,
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={form.id: form}))
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
    assert record.status == HumanInputFormStatus.TIMEOUT
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """Timing out a form that was already submitted raises FormNotFoundError."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.SUBMITTED,
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={form.id: form}))
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form already submitted"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
    """Marking a WAITING form as EXPIRED clears every submission-related field."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        # Pre-populate submission fields to prove mark_timeout wipes them.
        selected_action_id="a",
        submitted_data="{}",
        submission_user_id="u",
        submission_end_user_id="eu",
        completed_by_recipient_id="r",
        status=HumanInputFormStatus.WAITING,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
    assert form.status == HumanInputFormStatus.EXPIRED
    assert form.selected_action_id is None
    assert form.submitted_data is None
    assert form.submission_user_id is None
    assert form.submission_end_user_id is None
    assert form.completed_by_recipient_id is None
    assert record.status == HumanInputFormStatus.EXPIRED
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_timeout raises FormNotFoundError for an unknown form id."""
    _patch_session_factory(monkeypatch, _FakeSession(forms={}))
    with pytest.raises(FormNotFoundError, match="form not found"):
        HumanInputFormSubmissionRepository().mark_timeout(
            form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT
        )

View File

@ -1,84 +1,291 @@
from datetime import datetime
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
session_factory = MagicMock(spec=sessionmaker)
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
session_factory.return_value.__enter__.return_value = session
return session_factory
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
@pytest.fixture
def mock_engine():
"""Mock SQLAlchemy Engine."""
return MagicMock(spec=Engine)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = MagicMock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = MagicMock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
outputs={"output1": "result1"},
status=WorkflowExecutionStatus.SUCCEEDED,
error_message="",
total_tokens=100,
total_steps=5,
exceptions_count=0,
started_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()
class TestSQLAlchemyWorkflowExecutionRepository:
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
app_id = "test_app_id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
)
assert repo._session_factory == mock_session_factory
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
def test_init_with_engine(self, mock_engine, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_engine,
user=mock_account,
app_id="test_app_id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert isinstance(repo._session_factory, sessionmaker)
assert repo._session_factory.kw["bind"] == mock_engine
def test_init_invalid_session_factory(self, mock_account):
with pytest.raises(ValueError, match="Invalid session_factory type"):
SQLAlchemyWorkflowExecutionRepository(
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
)
def test_init_no_tenant_id(self, mock_session_factory):
user = MagicMock(spec=Account)
user.current_tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
)
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
)
assert repo._tenant_id == mock_end_user.tenant_id
assert repo._creator_user_role == CreatorUserRole.END_USER
def test_to_domain_model(self, mock_session_factory, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
db_model = MagicMock(spec=WorkflowRun)
db_model.id = str(uuid4())
db_model.workflow_id = str(uuid4())
db_model.type = "workflow"
db_model.version = "1.0"
db_model.inputs_dict = {"in": "val"}
db_model.outputs_dict = {"out": "val"}
db_model.graph_dict = {"nodes": []}
db_model.status = "succeeded"
db_model.error = "some error"
db_model.total_tokens = 50
db_model.total_steps = 3
db_model.exceptions_count = 1
db_model.created_at = datetime.now(UTC)
db_model.finished_at = datetime.now(UTC)
domain_model = repo._to_domain_model(db_model)
assert domain_model.id_ == db_model.id
assert domain_model.workflow_id == db_model.workflow_id
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
assert domain_model.inputs == db_model.inputs_dict
assert domain_model.error_message == "some error"
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Make elapsed time deterministic to avoid flaky tests
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.id == sample_workflow_execution.id_
assert db_model.tenant_id == repo._tenant_id
assert db_model.app_id == "test_app"
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
assert db_model.status == sample_workflow_execution.status.value
assert db_model.total_tokens == sample_workflow_execution.total_tokens
assert db_model.elapsed_time == 10.0
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Test with empty/None fields
sample_workflow_execution.graph = None
sample_workflow_execution.inputs = None
sample_workflow_execution.outputs = None
sample_workflow_execution.error_message = None
sample_workflow_execution.finished_at = None
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.graph is None
assert db_model.inputs is None
assert db_model.outputs is None
assert db_model.error is None
assert db_model.elapsed_time == 0
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=None,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
db_model = repo._to_db_model(sample_workflow_execution)
assert not hasattr(db_model, "app_id") or db_model.app_id is None
assert db_model.tenant_id == repo._tenant_id
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
# Test triggered_from missing
with pytest.raises(ValueError, match="triggered_from is required"):
repo._to_db_model(sample_workflow_execution)
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo._creator_user_id = None
with pytest.raises(ValueError, match="created_by is required"):
repo._to_db_model(sample_workflow_execution)
repo._creator_user_id = "some_id"
repo._creator_user_role = None
with pytest.raises(ValueError, match="created_by_role is required"):
repo._to_db_model(sample_workflow_execution)
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
session = mock_session_factory.return_value.__enter__.return_value
session.merge.assert_called_once()
session.commit.assert_called_once()
# Check cache
assert sample_workflow_execution.id_ in repo._execution_cache
cached_model = repo._execution_cache[sample_workflow_execution.id_]
assert cached_model.id == sample_workflow_execution.id_
def test_save_uses_execution_started_at_when_record_does_not_exist(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
sample_workflow_execution.started_at = started_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = None
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution_id = sample_workflow_execution.id_
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repo._tenant_id
existing_run.created_at = existing_created_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = existing_run
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()

View File

@ -0,0 +1,772 @@
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
_filter_by_offload_type,
_find_first,
_replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
    """Build an Account mock carrying only the id and current_tenant_id attributes."""
    account = Mock(spec=Account)
    account.id = user_id
    account.current_tenant_id = tenant_id
    return account
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
    """Build an EndUser mock carrying only the id and tenant_id attributes."""
    end_user = Mock(spec=EndUser)
    end_user.id = user_id
    end_user.tenant_id = tenant_id
    return end_user
def _execution(
    *,
    execution_id: str = "exec-id",
    node_execution_id: str = "node-exec-id",
    workflow_run_id: str = "run-id",
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs: Mapping[str, Any] | None = None,
    outputs: Mapping[str, Any] | None = None,
    process_data: Mapping[str, Any] | None = None,
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
    """Build a domain-level WorkflowNodeExecution pre-filled with test defaults.

    Only the fields callers typically vary are exposed as keyword arguments;
    everything else is pinned to fixed values to keep assertions simple.
    """
    attrs: dict[str, Any] = {
        "id": execution_id,
        "node_execution_id": node_execution_id,
        "workflow_id": "workflow-id",
        "workflow_execution_id": workflow_run_id,
        "index": 1,
        "predecessor_node_id": None,
        "node_id": "node-id",
        "node_type": NodeType.LLM,
        "title": "Title",
        "inputs": inputs,
        "outputs": outputs,
        "process_data": process_data,
        "status": status,
        "error": None,
        "elapsed_time": 1.0,
        "metadata": metadata,
        "created_at": datetime.now(UTC),
        "finished_at": None,
    }
    return WorkflowNodeExecution(**attrs)
class _SessionCtx:
    """Trivial context manager that hands back a pre-supplied session object."""

    def __init__(self, session: Any):
        # Keep a reference so __enter__ yields exactly the object given.
        self._session = session

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None never suppresses exceptions from the `with` body.
        return None
def _session_factory(session: Any) -> sessionmaker:
    """Wrap *session* in a sessionmaker-shaped Mock whose call yields a context manager."""
    mock_factory = Mock(spec=sessionmaker)
    mock_factory.return_value = _SessionCtx(session)
    return mock_factory
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """The repository accepts either an Engine or a sessionmaker and derives the creator role."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    # An Engine is wrapped into a sessionmaker internally.
    sqlite_engine: Engine = create_engine("sqlite:///:memory:")
    engine_repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sqlite_engine,
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    assert isinstance(engine_repo._session_factory, sessionmaker)

    # An EndUser maps onto the "end_user" creator role.
    factory = Mock(spec=sessionmaker)
    end_user_repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=factory,
        user=_mock_end_user(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
    )
    assert end_user_repo._creator_user_role.value == "end_user"
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
    """Passing something that is neither an Engine nor a sessionmaker raises ValueError."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    # A plain object() satisfies neither accepted factory type.
    with pytest.raises(ValueError, match="Invalid session_factory type"):
        SQLAlchemyWorkflowNodeExecutionRepository(  # type: ignore[arg-type]
            session_factory=object(),
            user=_mock_account(),
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """An Account without a current tenant id is rejected at construction time."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    tenantless_user = _mock_account()
    tenantless_user.current_tenant_id = None

    with pytest.raises(ValueError, match="User must have a tenant_id"):
        SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=Mock(spec=sessionmaker),
            user=tenantless_user,
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
    """_create_truncator forwards dify_config truncation limits to VariableTruncator."""
    captured: dict[str, Any] = {}

    class RecordingTruncator:
        """Captures the keyword arguments the repository constructs it with."""

        def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
            captured["max_size_bytes"] = max_size_bytes
            captured["array_element_limit"] = array_element_limit
            captured["string_length_limit"] = string_length_limit

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
        RecordingTruncator,
    )
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    _ = repo._create_truncator()

    assert captured["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
    """Exercise the small module-level helpers for JSON dumps and offload lists."""
    # Keys are emitted in sorted order for deterministic serialization.
    assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'

    # _find_first: None when nothing matches, otherwise the first match.
    assert _find_first([], lambda _: True) is None
    assert _find_first([1, 2, 3], lambda x: x > 1) == 2

    inputs_offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    outputs_offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    matcher = _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)
    assert _find_first([inputs_offload, outputs_offload], matcher) is outputs_offload

    # Replacing an existing INPUTS entry keeps the list size but appends the
    # replacement at the end.
    replacement = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    updated = _replace_or_append_offload([inputs_offload, outputs_offload], replacement)
    assert len(updated) == 2
    assert [entry.type_ for entry in updated] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_db_model serializes deterministically and insists on a triggered_from value."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})

    # Happy path: JSON payloads come out with sorted keys.
    db_model = repo._to_db_model(execution)
    assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
    assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1

    # Clearing triggered_from must make the conversion fail loudly.
    repo._triggered_from = None
    with pytest.raises(ValueError, match="triggered_from is required"):
        repo._to_db_model(execution)
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_db_model fails when the creator user id or role has been cleared."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution()

    # Baseline: with full context the app id is carried over to the row.
    assert repo._to_db_model(execution).app_id == "app"

    repo._creator_user_id = None
    with pytest.raises(ValueError, match="created_by is required"):
        repo._to_db_model(execution)

    repo._creator_user_id = "user"
    repo._creator_user_role = None
    with pytest.raises(ValueError, match="created_by_role is required"):
        repo._to_db_model(execution)
def test_is_duplicate_key_error_and_regenerate_id(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Duplicate-key detection matches UniqueViolation; id regeneration warns and rewrites ids."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    # Only IntegrityErrors wrapping a psycopg2 UniqueViolation count as duplicates.
    violation = Mock(spec=psycopg2.errors.UniqueViolation)
    assert repo._is_duplicate_key_error(IntegrityError("dup", params=None, orig=violation)) is True
    assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False

    execution = _execution(execution_id="old-id")
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "old-id"
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    caplog.set_level(logging.WARNING)

    repo._regenerate_id_on_duplicate(execution, db_model)

    # Both the domain object and the db row receive the freshly generated id.
    assert execution.id == "new-id"
    assert db_model.id == "new-id"
    assert any("Duplicate key conflict" in record.message for record in caplog.records)
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
    """_persist_to_database copies attrs onto an existing row, else add()s, and caches either way."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id1"
    db_model.node_execution_id = "node1"
    db_model.foo = "bar"  # type: ignore[attr-defined]
    # A private (underscore) attribute is planted to exercise the attr-copy
    # path; it is not asserted on directly below.
    db_model.__dict__["_private"] = "x"
    existing = SimpleNamespace()
    # Case 1: a row with the same id already exists -> attrs copied, no add().
    session.get.return_value = existing
    repo._persist_to_database(db_model)
    assert existing.foo == "bar"
    session.add.assert_not_called()
    assert repo._node_execution_cache["node1"] is db_model
    # Case 2: no existing row -> the model is added to the session instead.
    session.reset_mock()
    session.get.return_value = None
    repo._node_execution_cache.clear()
    repo._persist_to_database(db_model)
    session.add.assert_called_once_with(db_model)
    assert repo._node_execution_cache["node1"] is db_model
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
    """_truncate_and_upload short-circuits for missing values and untruncated mappings."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    # No values at all -> nothing to truncate or upload.
    assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None

    class PassthroughTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            # Report "not truncated" so no offload should be produced.
            return value, False

    monkeypatch.setattr(repo, "_create_truncator", lambda: PassthroughTruncator())
    assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
    """When truncation happens, the full payload is uploaded and an offload record is built."""
    uploaded: dict[str, Any] = {}

    class RecordingFileService:
        def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any):  # type: ignore[no-untyped-def]
            # Capture the upload so the filename/metadata can be asserted on.
            uploaded["filename"] = filename
            uploaded["content"] = content
            uploaded["mimetype"] = mimetype
            uploaded["user"] = user
            return SimpleNamespace(id="file-id", key="file-key")

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: RecordingFileService(),
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    class AlwaysTruncates:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            return {"truncated": True}, True

    monkeypatch.setattr(repo, "_create_truncator", lambda: AlwaysTruncates())

    result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)

    assert result is not None
    assert result.truncated_value == {"truncated": True}
    # The uploaded filename encodes the execution id and the offload type.
    assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
    assert result.offload.file_id == "file-id"
    assert result.offload.type_ == ExecutionOffLoadType.INPUTS
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_domain_model restores full offloaded payloads from storage.

    The DB row carries truncated JSON for inputs/outputs/process_data plus one
    offload record per field pointing at a storage key. The domain model should
    expose the full (storage-loaded) values while keeping the truncated
    variants reachable through the get_truncated_* accessors.
    """
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # Fully-populated DB row whose JSON fields hold the truncated values.
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"trunc": "i"})
    db_model.process_data = json.dumps({"trunc": "p"})
    db_model.outputs = json.dumps({"trunc": "o"})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = json.dumps({"total_tokens": 3})
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    # One offload record per field. The list order is deliberately shuffled to
    # show lookup is keyed by offload type, not by position.
    off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
    off_in.file = SimpleNamespace(key="k-in")
    off_out.file = SimpleNamespace(key="k-out")
    off_proc.file = SimpleNamespace(key="k-proc")
    db_model.offload_data = [off_out, off_in, off_proc]
    def fake_load(key: str) -> bytes:
        # Echo the storage key back so each field's source is identifiable.
        return json.dumps({"full": key}).encode()
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
    domain = repo._to_domain_model(db_model)
    # Full values come from storage; truncated values stay accessible.
    assert domain.inputs == {"full": "k-in"}
    assert domain.outputs == {"full": "k-out"}
    assert domain.process_data == {"full": "k-proc"}
    assert domain.get_truncated_inputs() == {"trunc": "i"}
    assert domain.get_truncated_outputs() == {"trunc": "o"}
    assert domain.get_truncated_process_data() == {"trunc": "p"}
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without offload records, the row's JSON payloads are used verbatim."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # Row with inline JSON payloads and an empty offload list.
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"i": 1})
    db_model.process_data = json.dumps({"p": 2})
    db_model.outputs = json.dumps({"o": 3})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = "{}"
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    db_model.offload_data = []
    domain = repo._to_domain_model(db_model)
    # No storage access needed: values are decoded straight from the row.
    assert domain.inputs == {"i": 1}
    assert domain.outputs == {"o": 3}
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
    """_json_encode routes values through WorkflowRuntimeTypeConverter before dumping."""

    class WrappingConverter:
        def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
            # Prove the converter output (not the raw input) gets serialized.
            return {"wrapped": values["a"]}

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
        WrappingConverter,
    )

    assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
    """save_execution_data merges truncated inputs into an existing row.

    The session is stubbed so the lookup returns an existing row that already
    carries an INPUTS offload; _truncate_and_upload is stubbed to truncate
    only the inputs mapping, so outputs/process_data must be encoded verbatim.
    """
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # The row-lookup query finds an existing model with a stale INPUTS offload.
    session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
        id="id",
        offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session.merge = Mock()
    session.flush = Mock()
    # Make `with session.begin():` usable on the MagicMock.
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
    trunc_result = SimpleNamespace(
        truncated_value={"trunc": True},
        offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
    )
    # Only the inputs mapping reports as truncated; other fields pass through.
    monkeypatch.setattr(
        repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
    )
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    # Inputs should be truncated, outputs/process_data encoded directly
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.inputs) == {"trunc": True}
    assert json.loads(db_model.outputs) == {"b": 2}
    assert json.loads(db_model.process_data) == {"c": 3}
    assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
    assert execution.get_truncated_inputs() == {"trunc": True}
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """save_execution_data applies truncation results to outputs and process_data too."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    existing = SimpleNamespace(
        id="id",
        offload_data=[],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session = MagicMock()
    session.execute.return_value.scalars.return_value.first.return_value = existing
    session.merge = Mock()
    session.flush = Mock()
    # Make `with session.begin():` usable on the MagicMock.
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
    def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
        # Stub: only outputs and process_data report as truncated; inputs pass.
        if values == {"b": 2}:
            return SimpleNamespace(
                truncated_value={"b": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
            )
        if values == {"c": 3}:
            return SimpleNamespace(
                truncated_value={"c": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
            )
        return None
    monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    db_model = session.merge.call_args.args[0]
    # Truncated values are persisted and mirrored back onto the domain object.
    assert json.loads(db_model.outputs) == {"b": "trunc"}
    assert json.loads(db_model.process_data) == {"c": "trunc"}
    assert execution.get_truncated_outputs() == {"b": "trunc"}
    assert execution.get_truncated_process_data() == {"c": "trunc"}
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
    """When no persisted row exists, save_execution_data builds one via _to_db_model."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # Simulate "row not found" for the lookup query.
    session.execute.return_value.scalars.return_value.first.return_value = None
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1})
    stub_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
    monkeypatch.setattr(repo, "_to_db_model", lambda *_: stub_db_model)
    monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))

    repo.save_execution_data(execution)

    merged_model = session.merge.call_args.args[0]
    assert merged_model.inputs == '{"a": 1}'
def test_save_retries_duplicate_and_logs_non_duplicate(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """save() retries with a regenerated id on duplicate-key errors; other
    integrity errors are logged and re-raised unchanged."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(execution_id="id")
    # Only an IntegrityError wrapping a UniqueViolation counts as a duplicate.
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    other_error = IntegrityError("other", params=None, orig=None)
    calls = {"n": 0}
    def persist(_db_model: Any) -> None:
        # Fail once with a duplicate-key error, then succeed on the retry.
        calls["n"] += 1
        if calls["n"] == 1:
            raise duplicate_error
    monkeypatch.setattr(repo, "_persist_to_database", persist)
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    repo.save(execution)
    # The retry succeeded with a regenerated id and the model was cached.
    assert execution.id == "new-id"
    assert repo._node_execution_cache[execution.node_execution_id] is not None
    caplog.set_level(logging.ERROR)
    # A non-duplicate integrity error must be logged and propagated.
    monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
    with pytest.raises(IntegrityError):
        repo.save(_execution(execution_id="id2", node_execution_id="node2"))
    assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
def test_save_logs_and_reraises_on_unexpected_error(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Unexpected persistence errors are logged and then propagated unchanged."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    caplog.set_level(logging.ERROR)

    def explode(_db: Any) -> None:
        raise RuntimeError("boom")

    monkeypatch.setattr(repo, "_persist_to_database", explode)

    with pytest.raises(RuntimeError, match="boom"):
        repo.save(_execution(execution_id="id3", node_execution_id="node3"))
    assert any("Failed to save workflow node execution" in record.message for record in caplog.records)
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_db_models_by_workflow_run applies the order config and caches rows by node_execution_id."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    class FakeStmt:
        """Chainable statement stub that records where()/order_by() usage."""
        def __init__(self) -> None:
            self.where_calls = 0
            self.order_by_args: tuple[Any, ...] | None = None
        def where(self, *_args: Any) -> FakeStmt:
            self.where_calls += 1
            return self
        def order_by(self, *args: Any) -> FakeStmt:
            self.order_by_args = args
            return self
    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    # model2 has no node_execution_id and so must not be cached.
    model1 = SimpleNamespace(node_execution_id="n1")
    model2 = SimpleNamespace(node_execution_id=None)
    session = MagicMock()
    session.scalars.return_value.all.return_value = [model1, model2]
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # "missing" is not a model column — presumably skipped by the ordering
    # logic; only the outcome (order_by was invoked) is asserted here.
    order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
    db_models = repo.get_db_models_by_workflow_run("run", order)
    assert db_models == [model1, model2]
    assert repo._node_execution_cache["n1"] is model1
    assert stmt.order_by_args is not None
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
    """An ascending order_direction is accepted and executes without error."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    class ChainableStmt:
        def where(self, *_args: Any) -> ChainableStmt:
            return self

        def order_by(self, *args: Any) -> ChainableStmt:
            self.args = args  # type: ignore[attr-defined]
            return self

    stmt = ChainableStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    session = MagicMock()
    session.scalars.return_value.all.return_value = []
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_workflow_run maps each DB model to a domain model via the executor."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
    monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
    monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")

    class SynchronousExecutor:
        """Runs map() inline so the test stays deterministic and single-threaded."""

        def __enter__(self) -> SynchronousExecutor:
            return self

        def __exit__(self, exc_type, exc, tb) -> None:
            return None

        def map(self, func, items, timeout: int):  # type: ignore[no-untyped-def]
            # The repository is expected to bound the mapping with a 30s timeout.
            assert timeout == 30
            return [func(item) for item in items]

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
        lambda max_workers: SynchronousExecutor(),
    )

    assert repo.get_by_workflow_run("run", order_config=None) == ["domain:db1", "domain:db2"]

View File

@ -0,0 +1,137 @@
import json
from unittest.mock import patch
from core.schemas.registry import SchemaRegistry
class TestSchemaRegistry:
    """Unit tests for SchemaRegistry: loading schema files from a versioned
    directory tree and serving lookups by URI, version, and schema name."""
    def test_initialization(self, tmp_path):
        """A new registry keeps the given base dir and starts with empty maps."""
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        registry = SchemaRegistry(str(base_dir))
        assert registry.base_dir == base_dir
        assert registry.versions == {}
        assert registry.metadata == {}
    def test_default_registry_singleton(self):
        """default_registry() returns one shared SchemaRegistry instance."""
        registry1 = SchemaRegistry.default_registry()
        registry2 = SchemaRegistry.default_registry()
        assert registry1 is registry2
        assert isinstance(registry1, SchemaRegistry)
    def test_load_all_versions_non_existent_dir(self, tmp_path):
        """Loading from a missing base dir is a no-op, not an error."""
        base_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(base_dir))
        registry.load_all_versions()
        assert registry.versions == {}
    def test_load_all_versions_filtering(self, tmp_path):
        """Non-version entries (other dirs, plain files) are skipped; only "v1" loads."""
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        (base_dir / "not_a_version_dir").mkdir()
        (base_dir / "v1").mkdir()
        (base_dir / "some_file.txt").write_text("content")
        registry = SchemaRegistry(str(base_dir))
        with patch.object(registry, "_load_version_dir") as mock_load:
            registry.load_all_versions()
            mock_load.assert_called_once()
            assert mock_load.call_args[0][0] == "v1"
    def test_load_version_dir_filtering(self, tmp_path):
        """Only *.json files inside a version dir are treated as schemas."""
        version_dir = tmp_path / "v1"
        version_dir.mkdir()
        (version_dir / "schema1.json").write_text("{}")
        (version_dir / "not_a_schema.txt").write_text("content")
        registry = SchemaRegistry(str(tmp_path))
        with patch.object(registry, "_load_schema") as mock_load:
            registry._load_version_dir("v1", version_dir)
            mock_load.assert_called_once()
            assert mock_load.call_args[0][1] == "schema1"
    def test_load_version_dir_non_existent(self, tmp_path):
        """A missing version directory registers nothing."""
        version_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(tmp_path))
        registry._load_version_dir("v1", version_dir)
        assert "v1" not in registry.versions
    def test_load_schema_success(self, tmp_path):
        """Loading a schema stores its content and indexes its metadata by URI."""
        schema_path = tmp_path / "test.json"
        schema_content = {"title": "Test Schema", "description": "A test schema"}
        schema_path.write_text(json.dumps(schema_content))
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "test", schema_path)
        assert registry.versions["v1"]["test"] == schema_content
        uri = "https://dify.ai/schemas/v1/test.json"
        assert registry.metadata[uri]["title"] == "Test Schema"
        assert registry.metadata[uri]["version"] == "v1"
    def test_load_schema_invalid_json(self, tmp_path, caplog):
        """Malformed JSON in a schema file is logged, not raised."""
        schema_path = tmp_path / "invalid.json"
        schema_path.write_text("invalid json")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "invalid", schema_path)
        assert "Failed to load schema v1/invalid" in caplog.text
    def test_load_schema_os_error(self, tmp_path, caplog):
        """I/O failures while reading a schema file are logged, not raised."""
        schema_path = tmp_path / "error.json"
        schema_path.write_text("{}")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        with patch("builtins.open", side_effect=OSError("Read error")):
            registry._load_schema("v1", "error", schema_path)
        assert "Failed to load schema v1/error" in caplog.text
    def test_get_schema(self):
        """get_schema resolves only well-formed URIs for known versions."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"type": "object"}}}
        # Valid URI
        assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
        # Invalid URI
        assert registry.get_schema("invalid-uri") is None
        # Missing version
        assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
    def test_list_versions(self):
        """Versions are listed in sorted order."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v2": {}, "v1": {}}
        assert registry.list_versions() == ["v1", "v2"]
    def test_list_schemas(self):
        """Schema names come back sorted; unknown versions yield an empty list."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"b": {}, "a": {}}}
        assert registry.list_schemas("v1") == ["a", "b"]
        assert registry.list_schemas("v2") == []
    def test_get_all_schemas_for_version(self):
        """Each schema is exposed as name/label/schema; the label falls back to the name."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"title": "Test Label"}}}
        results = registry.get_all_schemas_for_version("v1")
        assert len(results) == 1
        assert results[0]["name"] == "test"
        assert results[0]["label"] == "Test Label"
        assert results[0]["schema"] == {"title": "Test Label"}
        # Default label if title missing
        registry.versions["v1"]["no_title"] = {}
        results = registry.get_all_schemas_for_version("v1")
        item = next(r for r in results if r["name"] == "no_title")
        assert item["label"] == "no_title"
        # Empty if version missing
        assert registry.get_all_schemas_for_version("v2") == []

View File

@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch
from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager
def test_init_with_provided_registry():
    """An explicitly supplied registry is stored on the manager as-is."""
    supplied = MagicMock(spec=SchemaRegistry)
    manager = SchemaManager(registry=supplied)
    assert manager.registry == supplied
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
    """Without an explicit registry the manager falls back to the default singleton."""
    default = MagicMock(spec=SchemaRegistry)
    mock_default_registry.return_value = default
    manager = SchemaManager()
    mock_default_registry.assert_called_once()
    assert manager.registry == default
def test_get_all_schema_definitions():
    """Delegates straight to registry.get_all_schemas_for_version."""
    registry = MagicMock(spec=SchemaRegistry)
    definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
    registry.get_all_schemas_for_version.return_value = definitions
    manager = SchemaManager(registry=registry)
    assert manager.get_all_schema_definitions(version="v2") == definitions
    registry.get_all_schemas_for_version.assert_called_once_with("v2")
def test_get_schema_by_name_success():
    """Builds the canonical schema URI and wraps the resolved schema with its name."""
    registry = MagicMock(spec=SchemaRegistry)
    schema = {"type": "object"}
    registry.get_schema.return_value = schema
    manager = SchemaManager(registry=registry)
    result = manager.get_schema_by_name("my_schema", version="v1")
    registry.get_schema.assert_called_once_with("https://dify.ai/schemas/v1/my_schema.json")
    assert result == {"name": "my_schema", "schema": schema}
def test_get_schema_by_name_not_found():
    """Unknown schema names resolve to None rather than raising."""
    registry = MagicMock(spec=SchemaRegistry)
    registry.get_schema.return_value = None
    manager = SchemaManager(registry=registry)
    assert manager.get_schema_by_name("non_existent", version="v1") is None
def test_list_available_schemas():
    """Delegates to registry.list_schemas for the requested version."""
    registry = MagicMock(spec=SchemaRegistry)
    registry.list_schemas.return_value = ["schema1", "schema2"]
    manager = SchemaManager(registry=registry)
    assert manager.list_available_schemas(version="v1") == ["schema1", "schema2"]
    registry.list_schemas.assert_called_once_with("v1")
def test_list_available_versions():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_versions = ["v1", "v2"]
mock_registry.list_versions.return_value = expected_versions
manager = SchemaManager(registry=mock_registry)
result = manager.list_available_versions()
mock_registry.list_versions.assert_called_once()
assert result == expected_versions

View File

@ -1,32 +1,34 @@
from unittest.mock import Mock, PropertyMock, patch
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
def mock_provider_entity():
mock_entity = Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
provider_credential_schema = Mock()
type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
model_credential_schema = Mock()
type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
tenant_id="tenant_id",
@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
load_balancing_model_configs[0].id = "id1"
load_balancing_model_configs[1].id = "id2"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings_only_one_lb(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
]
load_balancing_model_configs[0].id = "id1"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings_lb_disabled(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
tenant_id="tenant_id",
@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
load_balancing_model_configs[0].id = "id1"
load_balancing_model_configs[1].id = "id2"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0
def test_get_default_model_uses_first_available_active_model():
mock_session = Mock()
mock_session.scalar.return_value = None
provider_configurations = Mock()
provider_configurations.get_models.return_value = [
Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")),
Mock(model="gpt-4", provider=Mock(provider="openai")),
]
manager = ProviderManager()
with (
patch("core.provider_manager.db.session", mock_session),
patch.object(manager, "get_configurations", return_value=provider_configurations),
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
):
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
supported_model_types=[ModelType.LLM],
)
result = manager.get_default_model("tenant-id", ModelType.LLM)
assert result is not None
assert result.model == "gpt-3.5-turbo"
assert result.provider.provider == "openai"
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
mock_session.add.assert_called_once()
saved_default_model = mock_session.add.call_args.args[0]
assert saved_default_model.model_name == "gpt-3.5-turbo"
assert saved_default_model.provider_name == "openai"
mock_session.commit.assert_called_once()

View File

@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode
from dify_graph.nodes.http_request import HttpRequestNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
@ -68,6 +68,8 @@ class MockNodeMixin:
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@ -4,9 +4,7 @@ from __future__ import annotations
import pytest
import dify_graph.graph_engine.response_coordinator.session as response_session_module
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType
from dify_graph.graph_engine.response_coordinator import RESPONSE_SESSION_NODE_TYPES
from dify_graph.graph_engine.response_coordinator.session import ResponseSession
from dify_graph.nodes.base.template import Template, TextSegment
@ -35,28 +33,14 @@ class DummyNodeWithoutStreamingTemplate:
self.state = NodeState.UNKNOWN
def test_response_session_from_node_rejects_node_types_outside_allowlist() -> None:
"""Unsupported node types are rejected even if they expose a template."""
def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None:
"""Session creation depends on the streaming-template contract rather than node type."""
node = DummyResponseNode(
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM,
template=Template(segments=[TextSegment(text="hello")]),
)
with pytest.raises(TypeError, match="RESPONSE_SESSION_NODE_TYPES"):
ResponseSession.from_node(node)
def test_response_session_from_node_supports_downstream_allowlist_extension(monkeypatch) -> None:
"""Downstream applications can extend the supported node-type list."""
node = DummyResponseNode(
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM,
template=Template(segments=[TextSegment(text="hello")]),
)
extended_node_types = [*RESPONSE_SESSION_NODE_TYPES, BuiltinNodeTypes.LLM]
monkeypatch.setattr(response_session_module, "RESPONSE_SESSION_NODE_TYPES", extended_node_types)
session = ResponseSession.from_node(node)
assert session.node_id == "llm-node"

View File

@ -0,0 +1,145 @@
import queue
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
from dify_graph.graph_engine.worker import Worker
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
worker = Worker(
ready_queue=InMemoryReadyQueue(),
event_queue=queue.Queue(),
graph=MagicMock(),
layers=[],
)
node = SimpleNamespace(
execution_id="exec-1",
id="node-1",
node_type=BuiltinNodeTypes.LLM,
)
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
assert event.start_at == fixed_time
assert event.finished_at == fixed_time
assert event.error == "boom"
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
assert event.node_run_result.error == "boom"
assert event.node_run_result.error_type == "RuntimeError"
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
failure_time = start_at + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeNode:
execution_id = "exec-1"
id = "node-1"
node_type = BuiltinNodeTypes.LLM
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="LLM",
start_at=start_at,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"node-1": FakeNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["node-1"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 1:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == start_at
assert fallback_event.finished_at == failure_time
assert fallback_event.error == "queue boom"
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
child_start = parent_start + timedelta(seconds=3)
failure_time = parent_start + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeIterationNode:
execution_id = "iteration-exec"
id = "iteration-node"
node_type = BuiltinNodeTypes.ITERATION
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="Iteration",
start_at=parent_start,
)
yield NodeRunStartedEvent(
id="child-exec",
node_id="child-node",
node_type=BuiltinNodeTypes.LLM,
node_title="LLM",
start_at=child_start,
in_iteration_id=self.id,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["iteration-node"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 2:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == parent_start
assert fallback_event.finished_at == failure_time

View File

@ -14,3 +14,64 @@ def test_render_body_template_replaces_variable_values():
result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool)
assert result == "Hello World https://example.com"
def test_render_markdown_body_renders_markdown_to_html():
rendered = EmailDeliveryConfig.render_markdown_body("**Bold** and [link](https://example.com)")
assert "<strong>Bold</strong>" in rendered
assert '<a href="https://example.com">link</a>' in rendered
def test_render_markdown_body_sanitizes_unsafe_html():
rendered = EmailDeliveryConfig.render_markdown_body(
'<script>alert("xss")</script><a href="javascript:alert(1)" onclick="alert(2)">Click</a>'
)
assert "<script" not in rendered
assert "<a" not in rendered
assert "onclick" not in rendered
assert "javascript:" not in rendered
assert "Click" in rendered
def test_render_markdown_body_sanitizes_markdown_link_with_javascript_href():
rendered = EmailDeliveryConfig.render_markdown_body("[bad](javascript:alert(1)) and [ok](https://example.com)")
assert "javascript:" not in rendered
assert "<a>bad</a>" in rendered
assert '<a href="https://example.com">ok</a>' in rendered
def test_render_markdown_body_does_not_allow_raw_html_tags():
rendered = EmailDeliveryConfig.render_markdown_body("<b>raw html</b> and **markdown**")
assert "<b>" not in rendered
assert "raw html" in rendered
assert "<strong>markdown</strong>" in rendered
def test_render_markdown_body_supports_table_syntax():
rendered = EmailDeliveryConfig.render_markdown_body("| h1 | h2 |\n| --- | ---: |\n| v1 | v2 |")
assert "<table>" in rendered
assert "<thead>" in rendered
assert "<tbody>" in rendered
assert 'align="right"' in rendered
assert "style=" not in rendered
def test_sanitize_subject_removes_crlf():
sanitized = EmailDeliveryConfig.sanitize_subject("Notice\r\nBCC:attacker@example.com")
assert "\r" not in sanitized
assert "\n" not in sanitized
assert sanitized == "Notice BCC:attacker@example.com"
def test_sanitize_subject_removes_html_tags():
sanitized = EmailDeliveryConfig.sanitize_subject("<b>Alert</b><img src=x onerror=1>")
assert "<" not in sanitized
assert ">" not in sanitized
assert sanitized == "Alert"

View File

@ -0,0 +1,63 @@
import time
from contextlib import nullcontext
from datetime import UTC, datetime
import pytest
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.iteration_node import IterationNode
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration", "output"],
is_parallel=True,
parallel_nums=2,
error_handle_mode=ErrorHandleMode.TERMINATED,
)
node._capture_execution_context = lambda: nullcontext()
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
return (
0.1 + (index * 0.1),
[
NodeRunSucceededEvent(
id=f"exec-{index}",
node_id=f"llm-{index}",
node_type=BuiltinNodeTypes.LLM,
start_at=datetime.now(UTC).replace(tzinfo=None),
),
],
f"output-{item}",
{},
LLMUsage.empty_usage(),
)
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
outputs: list[object] = []
iter_run_map: dict[str, float] = {}
usage_accumulator = [LLMUsage.empty_usage()]
generator = node._execute_parallel_iterations(
iterator_list_value=["a", "b"],
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
for _ in generator:
# Simulate a slow consumer replaying buffered events.
time.sleep(0.02)
assert outputs == ["output-a", "output-b"]
assert iter_run_map["0"] == pytest.approx(0.1)
assert iter_run_map["1"] == pytest.approx(0.2)

View File

@ -1,18 +1,26 @@
"""Tests for llm_utils module, specifically multimodal content handling."""
"""Tests for llm_utils module, specifically multimodal content handling and prompt message construction."""
import string
from unittest import mock
from unittest.mock import patch
import pytest
from core.model_manager import ModelInstance
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
from dify_graph.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.nodes.llm import llm_utils
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
from dify_graph.nodes.llm.exc import NoPromptFoundError
from dify_graph.nodes.llm.llm_utils import (
_truncate_multimodal_content,
build_context,
restore_multimodal_content_in_messages,
)
from dify_graph.runtime import VariablePool
class TestTruncateMultimodalContent:
@ -50,7 +58,6 @@ class TestTruncateMultimodalContent:
assert isinstance(result_content, ImagePromptMessageContent)
assert result_content.base64_data == ""
assert result_content.url == ""
# file_ref should be preserved
assert result_content.file_ref == "local:test-file-id"
def test_truncates_base64_when_no_file_ref(self):
@ -70,7 +77,6 @@ class TestTruncateMultimodalContent:
assert isinstance(result.content, list)
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
# Should be truncated with marker
assert "...[TRUNCATED]..." in result_content.base64_data
assert len(result_content.base64_data) < len(long_base64)
@ -89,9 +95,7 @@ class TestTruncateMultimodalContent:
assert isinstance(result.content, list)
assert len(result.content) == 2
# Text content unchanged
assert result.content[0].data == "Hello!"
# Image content base64 cleared
assert result.content[1].base64_data == ""
@ -100,8 +104,6 @@ class TestBuildContext:
def test_excludes_system_messages(self):
"""System messages should be excluded from context."""
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Hello!"),
@ -109,7 +111,6 @@ class TestBuildContext:
context = build_context(messages, "Hi there!")
# Should have user message + assistant response, no system message
assert len(context) == 2
assert context[0].content == "Hello!"
assert context[1].content == "Hi there!"
@ -140,7 +141,6 @@ class TestBuildContext:
messages = [UserPromptMessage(content="What's the weather in Beijing?")]
# Create trace with tool call and result
generation_data = LLMGenerationData(
text="The weather in Beijing is sunny, 25°C.",
reasoning_contents=[],
@ -183,7 +183,6 @@ class TestBuildContext:
accumulated_response = "Let me check the weather.The weather in Beijing is sunny, 25°C."
context = build_context(messages, accumulated_response, generation_data)
# Should have: user message + assistant with tool_call + tool result + final assistant
assert len(context) == 4
assert context[0].content == "What's the weather in Beijing?"
assert isinstance(context[1], AssistantPromptMessage)
@ -223,7 +222,6 @@ class TestBuildContext:
finish_reason="stop",
files=[],
trace=[
# First model call with two tool calls
LLMTraceSegment(
type="model",
duration=0.5,
@ -237,7 +235,6 @@ class TestBuildContext:
],
),
),
# First tool result
LLMTraceSegment(
type="tool",
duration=0.2,
@ -249,7 +246,6 @@ class TestBuildContext:
output="Sunny, 25°C",
),
),
# Second tool result
LLMTraceSegment(
type="tool",
duration=0.2,
@ -267,7 +263,6 @@ class TestBuildContext:
accumulated_response = "I'll check both cities.Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
context = build_context(messages, accumulated_response, generation_data)
# Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
assert len(context) == 5
assert context[0].content == "Compare weather in Beijing and Shanghai"
assert isinstance(context[1], AssistantPromptMessage)
@ -304,12 +299,11 @@ class TestBuildContext:
usage=LLMUsage.empty_usage(),
finish_reason="stop",
files=[],
trace=[], # Empty trace
trace=[],
)
context = build_context(messages, "Hi there!", generation_data)
# Should fallback to simple context
assert len(context) == 2
assert context[0].content == "Hello!"
assert context[1].content == "Hi there!"
@ -321,7 +315,6 @@ class TestRestoreMultimodalContentInMessages:
@patch("dify_graph.file.file_manager.restore_multimodal_content")
def test_restores_multimodal_content(self, mock_restore):
"""Should restore multimodal content in messages."""
# Setup mock
restored_content = ImagePromptMessageContent(
format="png",
base64_data="restored-base64",
@ -330,7 +323,6 @@ class TestRestoreMultimodalContentInMessages:
)
mock_restore.return_value = restored_content
# Create message with truncated content
truncated_content = ImagePromptMessageContent(
format="png",
base64_data="",
@ -363,3 +355,98 @@ class TestRestoreMultimodalContentInMessages:
assert len(result) == 1
assert result[0].content[0].data == "Hello!"
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
prompt_template = [
LLMNodeChatModelMessage(
text="You are a classifier.",
role=PromptMessageRole.SYSTEM,
edition_type="basic",
)
]
with (
mock.patch(
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
return_value=mock.MagicMock(features=[]),
),
mock.patch(
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
return_value=[SystemPromptMessage(content=content)],
),
mock.patch(
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
return_value=[],
),
):
return llm_utils.fetch_prompt_messages(
sys_query=None,
sys_files=[],
context=None,
memory=None,
model_instance=model_instance,
prompt_template=prompt_template,
stop=["END"],
memory_config=None,
vision_enabled=False,
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
variable_pool=variable_pool,
jinja2_variables=[],
template_renderer=None,
)
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
with pytest.raises(NoPromptFoundError):
_fetch_prompt_messages_with_mocked_content(
[
ImagePromptMessageContent(
format="url",
url="https://example.com/image.png",
mime_type="image/png",
),
]
)
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
[
TextPromptMessageContent(data="You are a classifier."),
ImagePromptMessageContent(
format="url",
url="https://example.com/image.png",
mime_type="image/png",
),
]
)
assert stop == ["END"]
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
[
TextPromptMessageContent(data="You are"),
TextPromptMessageContent(data=" a classifier."),
ImagePromptMessageContent(
format="url",
url="https://example.com/image.png",
mime_type="image/png",
),
]
)
assert stop == ["END"]
assert prompt_messages == [
SystemPromptMessage(
content=[
TextPromptMessageContent(data="You are"),
TextPromptMessageContent(data=" a classifier."),
]
)
]

View File

@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
VisionConfigOptions,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@ -107,6 +107,7 @@ def llm_node(
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -121,6 +122,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node
@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
messages = [
LLMNodeChatModelMessage(
text="",
jinja2_text="Hello, {{ name }}",
role=PromptMessageRole.USER,
edition_type="jinja2",
)
]
result = llm_node.handle_list_messages(
messages=messages,
context=None,
jinja2_variables=[],
variable_pool=llm_node.graph_runtime_state.variable_pool,
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
template_renderer=llm_node._template_renderer,
)
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
llm_node._template_renderer.render_jinja2.assert_called_once_with(
template="Hello, {{ name }}",
inputs={},
)
def test_handle_memory_completion_mode_uses_prompt_message_interface():
memory = mock.MagicMock(spec=MockTokenBufferMemory)
memory.get_history_prompt_messages.return_value = [
@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
window=MemoryConfig.WindowConfig(enabled=True, size=3),
)
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = _handle_memory_completion_mode(
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = llm_utils.handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node, mock_file_saver

View File

@ -1,5 +1,14 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import (
QuestionClassifierNode,
QuestionClassifierNodeData,
)
from tests.workflow_test_utils import build_test_graph_init_params
def test_init_question_classifier_node_data():
@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
assert node_data.vision.enabled == False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
node_data = QuestionClassifierNodeData.model_validate(
{
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
}
)
template_renderer = MagicMock(spec=TemplateRenderer)
node = QuestionClassifierNode(
id="node-id",
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
graph_init_params=build_test_graph_init_params(
workflow_id="workflow-id",
graph_config={},
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
),
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(),
http_client=MagicMock(spec=HttpClientProtocol),
llm_file_saver=MagicMock(),
template_renderer=template_renderer,
)
fetch_prompt_messages = MagicMock(return_value=([], None))
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
fetch_prompt_messages,
)
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
)
node._calculate_rest_token(
node_data=node_data,
query="hello",
model_instance=MagicMock(stop=(), parameters={}),
context="",
)
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer

View File

@ -0,0 +1,63 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
    """Assemble the (init params, runtime state) pair shared by the trigger-event tests."""
    params = build_test_graph_init_params(
        graph_config=graph_config,
        user_from="account",
        invoke_from="debugger",
    )
    # A minimal variable pool: one system user plus a single user input.
    pool = VariablePool(
        system_variables=SystemVariable(user_id="user", files=[]),
        user_inputs={"payload": "value"},
    )
    state = GraphRuntimeState(variable_pool=pool, start_at=0.0)
    return params, state
def _build_node_config() -> NodeConfigDict:
    """Return a validated, canonical trigger-event node config for the tests below."""
    raw_config = {
        "id": "node-1",
        "data": {
            "type": TRIGGER_PLUGIN_NODE_TYPE,
            "title": "Trigger Event",
            "plugin_id": "plugin-id",
            "provider_id": "provider-id",
            "event_name": "event-name",
            "subscription_id": "subscription-id",
            "plugin_unique_identifier": "plugin-unique-identifier",
            "event_parameters": {},
        },
    }
    # Validate through the adapter so the tests exercise the same schema as production.
    return NodeConfigDictAdapter.validate_python(raw_config)
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
    """A successful run must surface the provider/event identifiers as TRIGGER_INFO metadata."""
    init_params, runtime_state = _build_context(graph_config={})
    trigger_node = TriggerEventNode(
        id="node-1",
        config=_build_node_config(),
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    outcome = trigger_node._run()

    assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
    expected_info = {
        "provider_id": "provider-id",
        "event_name": "event-name",
        "plugin_unique_identifier": "plugin-unique-identifier",
    }
    assert outcome.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == expected_info

View File

@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
assert executor.is_execution_error(RuntimeError("boom")) is False
class TestDefaultLLMTemplateRenderer:
    def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
        """render_jinja2 must delegate verbatim to CodeExecutor.execute_workflow_code_template
        and unwrap the executor's "result" field."""
        renderer = node_factory.DefaultLLMTemplateRenderer()
        executor_spy = MagicMock(return_value={"result": "hello world"})
        monkeypatch.setattr(
            node_factory.CodeExecutor,
            "execute_workflow_code_template",
            executor_spy,
        )

        rendered = renderer.render_jinja2(
            template="Hello {{ name }}",
            inputs={"name": "world"},
        )

        assert rendered == "hello world"
        # The template and inputs must reach the executor unmodified, as JINJA2 code.
        executor_spy.assert_called_once_with(
            language=CodeLanguage.JINJA2,
            code="Hello {{ name }}",
            inputs={"name": "world"},
        )
class TestDifyNodeFactoryInit:
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
http_request_config = sentinel.http_request_config
credentials_provider = sentinel.credentials_provider
model_factory = sentinel.model_factory
llm_template_renderer = sentinel.llm_template_renderer
with (
patch.object(
@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
"build_http_request_config",
return_value=http_request_config,
),
patch.object(
node_factory,
"DefaultLLMTemplateRenderer",
return_value=llm_template_renderer,
) as llm_renderer_factory,
patch.object(
node_factory,
"build_dify_model_access",
@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
build_dify_model_access.assert_called_once_with("tenant-id")
renderer_factory.assert_called_once()
llm_renderer_factory.assert_called_once()
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
assert factory.graph_init_params is graph_init_params
assert factory.graph_runtime_state is graph_runtime_state
assert factory._dify_context is dify_context
assert factory._template_renderer is template_renderer
assert factory._llm_template_renderer is llm_template_renderer
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
assert factory._http_request_config is http_request_config
assert factory._llm_credentials_provider is credentials_provider
@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
factory._code_executor = sentinel.code_executor
factory._code_limits = sentinel.code_limits
factory._template_renderer = sentinel.template_renderer
factory._llm_template_renderer = sentinel.llm_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
[
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
(
BuiltinNodeTypes.LLM,
"LLMNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(
BuiltinNodeTypes.QUESTION_CLASSIFIER,
"QuestionClassifierNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
],
)