mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 23:48:04 +08:00
Merge main HEAD (segment 5) into sandboxed-agent-rebase
Resolve 83 conflicts: 10 backend, 62 frontend, 11 config/lock files. Preserve sandbox/agent/collaboration features while adopting main's UI refactorings (Dialog/AlertDialog/Popover), model provider updates, and enterprise features. Made-with: Cursor
This commit is contained in:
@ -234,6 +234,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
captured: dict[str, object] = {}
|
||||
prefill_calls: list[object] = []
|
||||
var_loader = SimpleNamespace(loader="draft")
|
||||
workflow = SimpleNamespace(id="workflow-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
|
||||
@ -260,8 +261,8 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
def __init__(self, session):
|
||||
_ = session
|
||||
|
||||
def prefill_conversation_variable_default_values(self, workflow):
|
||||
prefill_calls.append(workflow)
|
||||
def prefill_conversation_variable_default_values(self, workflow, user_id):
|
||||
prefill_calls.append((workflow, user_id))
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
|
||||
|
||||
@ -273,7 +274,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
|
||||
result = generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(id="workflow-id"),
|
||||
workflow=workflow,
|
||||
node_id="node-1",
|
||||
user=SimpleNamespace(id="user-id"),
|
||||
args={"inputs": {"foo": "bar"}},
|
||||
@ -281,7 +282,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert prefill_calls
|
||||
assert prefill_calls == [(workflow, "user-id")]
|
||||
assert captured["variable_loader"] is var_loader
|
||||
assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1"
|
||||
|
||||
@ -291,6 +292,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
captured: dict[str, object] = {}
|
||||
prefill_calls: list[object] = []
|
||||
var_loader = SimpleNamespace(loader="draft")
|
||||
workflow = SimpleNamespace(id="workflow-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
|
||||
@ -317,8 +319,8 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
def __init__(self, session):
|
||||
_ = session
|
||||
|
||||
def prefill_conversation_variable_default_values(self, workflow):
|
||||
prefill_calls.append(workflow)
|
||||
def prefill_conversation_variable_default_values(self, workflow, user_id):
|
||||
prefill_calls.append((workflow, user_id))
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
|
||||
|
||||
@ -330,7 +332,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
|
||||
result = generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(id="workflow-id"),
|
||||
workflow=workflow,
|
||||
node_id="node-2",
|
||||
user=SimpleNamespace(id="user-id"),
|
||||
args=SimpleNamespace(inputs={"foo": "bar"}),
|
||||
@ -338,7 +340,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert prefill_calls
|
||||
assert prefill_calls == [(workflow, "user-id")]
|
||||
assert captured["variable_loader"] is var_loader
|
||||
assert captured["application_generate_entity"].single_loop_run.node_id == "node-2"
|
||||
|
||||
|
||||
@ -44,11 +44,22 @@ class TestAgentChatAppGenerateResponseConverterBlocking:
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
@ -107,11 +118,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
"extra": "ignored",
|
||||
}
|
||||
@ -151,11 +173,22 @@ class TestAgentChatAppGenerateResponseConverterStream:
|
||||
assert "usage" not in metadata
|
||||
assert metadata["retriever_resources"] == [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
|
||||
@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
|
||||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_prefers_event_finished_at(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Finished timestamps should come from the event, not delayed queue processing time."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.common.workflow_response_converter.naive_utc_now",
|
||||
lambda: delayed_processing_time,
|
||||
)
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
node_execution_id="node-exec-1",
|
||||
start_at=start_at,
|
||||
finished_at=finished_at,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.elapsed_time == 2.0
|
||||
assert response.data.finished_at == int(finished_at.timestamp())
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
@ -38,11 +38,22 @@ class TestCompletionAppGenerateResponseConverter:
|
||||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset 1",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"data_source_type": "file",
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "abc1234",
|
||||
"content": "c",
|
||||
"page": 5,
|
||||
"title": "Citation Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "sum",
|
||||
"extra": "x",
|
||||
}
|
||||
@ -66,7 +77,12 @@ class TestCompletionAppGenerateResponseConverter:
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1"
|
||||
assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1"
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
|
||||
assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file"
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3
|
||||
assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234"
|
||||
assert "extra" not in result["metadata"]["retriever_resources"][0]
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_not_dict(self):
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api",
|
||||
from_source=ConversationFromSource.API,
|
||||
from_end_user_id="user-id",
|
||||
from_account_id=None,
|
||||
)
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.workflow.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
_NodeRuntimeSnapshot,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_layer() -> WorkflowPersistenceLayer:
|
||||
application_generate_entity = Mock()
|
||||
application_generate_entity.inputs = {}
|
||||
|
||||
return WorkflowPersistenceLayer(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version="1",
|
||||
graph_data={},
|
||||
),
|
||||
workflow_execution_repository=Mock(),
|
||||
workflow_node_execution_repository=Mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
layer = _build_layer()
|
||||
node_execution = Mock()
|
||||
node_execution.id = "node-exec-1"
|
||||
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
node_execution.update_from_mapping = Mock()
|
||||
|
||||
layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
|
||||
node_id="node-id",
|
||||
title="LLM",
|
||||
predecessor_node_id=None,
|
||||
iteration_id="iter-1",
|
||||
loop_id=None,
|
||||
created_at=node_execution.created_at,
|
||||
)
|
||||
|
||||
event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)
|
||||
|
||||
layer._update_node_execution(
|
||||
node_execution,
|
||||
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event_finished_at,
|
||||
)
|
||||
|
||||
assert node_execution.finished_at == event_finished_at
|
||||
assert node_execution.elapsed_time == 2.0
|
||||
@ -166,6 +166,7 @@ class TestDatasourceFileManager:
|
||||
# Setup
|
||||
mock_guess_ext.return_value = None # Cannot guess
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
|
||||
@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||
ProviderCredentialSchema,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
@ -409,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -
|
||||
|
||||
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
|
||||
|
||||
assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value
|
||||
assert existing_record.preferred_provider_type == ProviderType.SYSTEM
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||
id="lb-base",
|
||||
name="LB Base",
|
||||
credentials={},
|
||||
credential_source_type="provider",
|
||||
credential_source_type=CredentialSourceType.PROVIDER,
|
||||
)
|
||||
],
|
||||
),
|
||||
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||
id="lb-custom",
|
||||
name="LB Custom",
|
||||
credentials={},
|
||||
credential_source_type="custom_model",
|
||||
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
],
|
||||
),
|
||||
@ -734,7 +735,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
|
||||
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(is_valid=False)
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
@ -743,6 +744,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id == "existing-cred"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
|
||||
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
|
||||
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
|
||||
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id is not None
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
@ -807,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
|
||||
configuration._update_load_balancing_configs_with_credential(
|
||||
credential_id="cred-1",
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
|
||||
@ -825,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
|
||||
configuration._update_load_balancing_configs_with_credential(
|
||||
credential_id="cred-1",
|
||||
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
|
||||
|
||||
@ -801,6 +801,27 @@ class TestAuthOrchestration:
|
||||
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
|
||||
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
|
||||
|
||||
def test_build_protected_resource_metadata_discovery_urls_with_relative_hint(self):
|
||||
urls = build_protected_resource_metadata_discovery_urls(
|
||||
"/.well-known/oauth-protected-resource/tenant/mcp",
|
||||
"https://api.example.com/tenant/mcp",
|
||||
)
|
||||
assert urls == [
|
||||
"https://api.example.com/.well-known/oauth-protected-resource/tenant/mcp",
|
||||
"https://api.example.com/.well-known/oauth-protected-resource",
|
||||
]
|
||||
|
||||
def test_build_protected_resource_metadata_discovery_urls_ignores_scheme_less_hint(self):
|
||||
urls = build_protected_resource_metadata_discovery_urls(
|
||||
"/openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp",
|
||||
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/tenant/mcp",
|
||||
)
|
||||
|
||||
assert urls == [
|
||||
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp",
|
||||
"https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource",
|
||||
]
|
||||
|
||||
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
|
||||
# Case 1: with auth_server_url
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
|
||||
181
api/tests/unit_tests/core/moderation/api/test_api.py
Normal file
181
api/tests/unit_tests/core/moderation/api/test_api.py
Normal file
@ -0,0 +1,181 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
|
||||
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
|
||||
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
class TestApiModeration:
|
||||
@pytest.fixture
|
||||
def api_config(self):
|
||||
return {
|
||||
"inputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"outputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"api_based_extension_id": "test-extension-id",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def api_moderation(self, api_config):
|
||||
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
|
||||
|
||||
def test_moderation_input_params(self):
|
||||
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.inputs == {"key": "val"}
|
||||
assert params.query == "test query"
|
||||
|
||||
# Test defaults
|
||||
params_default = ModerationInputParams()
|
||||
assert params_default.app_id == ""
|
||||
assert params_default.inputs == {}
|
||||
assert params_default.query == ""
|
||||
|
||||
def test_moderation_output_params(self):
|
||||
params = ModerationOutputParams(app_id="app-1", text="test text")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.text == "test text"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ModerationOutputParams()
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_success(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
|
||||
def test_validate_config_missing_extension_id(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": True},
|
||||
}
|
||||
with pytest.raises(ValueError, match="api_based_extension_id is required"):
|
||||
ApiModeration.validate_config("test-tenant-id", config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
|
||||
|
||||
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
|
||||
|
||||
assert isinstance(result, ModerationInputsResult)
|
||||
assert result.flagged is True
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Blocked by API"
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_INPUT,
|
||||
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
|
||||
)
|
||||
|
||||
def test_moderation_for_inputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": False},
|
||||
"outputs_config": {"enabled": True},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_inputs(inputs={}, query="")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == ""
|
||||
|
||||
def test_moderation_for_inputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_inputs({}, "")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
|
||||
|
||||
result = api_moderation.moderation_for_outputs(text="hello world")
|
||||
|
||||
assert isinstance(result, ModerationOutputsResult)
|
||||
assert result.flagged is False
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
|
||||
)
|
||||
|
||||
def test_moderation_for_outputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": False},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_outputs(text="test")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
|
||||
def test_moderation_for_outputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_outputs("test")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
@patch("core.moderation.api.api.decrypt_token")
|
||||
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
|
||||
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_ext.api_endpoint = "http://api.test"
|
||||
mock_ext.api_key = "encrypted-key"
|
||||
mock_get_ext.return_value = mock_ext
|
||||
|
||||
mock_decrypt.return_value = "decrypted-key"
|
||||
|
||||
mock_requestor = MagicMock()
|
||||
mock_requestor.request.return_value = {"flagged": True}
|
||||
mock_requestor_cls.return_value = mock_requestor
|
||||
|
||||
params = {"some": "params"}
|
||||
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
assert result == {"flagged": True}
|
||||
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
|
||||
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
|
||||
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
def test_get_config_by_requestor_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
|
||||
mock_get_ext.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.db.session.scalar")
|
||||
def test_get_api_based_extension(self, mock_scalar):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_scalar.return_value = mock_ext
|
||||
|
||||
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
|
||||
|
||||
assert result == mock_ext
|
||||
mock_scalar.assert_called_once()
|
||||
# Verify the call has the correct filters
|
||||
args, kwargs = mock_scalar.call_args
|
||||
stmt = args[0]
|
||||
# We can't easily inspect the statement without complex sqlalchemy tricks,
|
||||
# but calling it is usually enough for unit tests if we mock the result.
|
||||
207
api/tests/unit_tests/core/moderation/test_input_moderation.py
Normal file
207
api/tests/unit_tests/core/moderation/test_input_moderation.py
Normal file
@ -0,0 +1,207 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
|
||||
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class TestInputModeration:
|
||||
@pytest.fixture
|
||||
def app_config(self):
|
||||
config = MagicMock(spec=AppConfig)
|
||||
config.sensitive_word_avoidance = None
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def input_moderation(self):
|
||||
return InputModeration()
|
||||
|
||||
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {"keywords": ["bad"]}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
mock_factory_cls.assert_called_once_with(
|
||||
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
|
||||
)
|
||||
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
@patch("core.moderation.input_moderation.TraceTask")
|
||||
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
trace_manager = MagicMock(spec=TraceQueueManager)
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
|
||||
mock_trace_task.assert_called_once()
|
||||
call_kwargs = mock_trace_task.call_args.kwargs
|
||||
call_args = mock_trace_task.call_args.args
|
||||
assert call_args[0] == TraceTaskName.MODERATION_TRACE
|
||||
assert call_kwargs["message_id"] == message_id
|
||||
assert call_kwargs["moderation_result"] == mock_result
|
||||
assert call_kwargs["inputs"] == inputs
|
||||
assert "timer" in call_kwargs
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
with pytest.raises(ModerationError) as excinfo:
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == "Blocked content"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True,
|
||||
action=ModerationAction.OVERRIDDEN,
|
||||
inputs={"input_key": "overridden_value"},
|
||||
query="overridden query",
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == {"input_key": "overridden_value"}
|
||||
assert final_query == "overridden query"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = MagicMock()
|
||||
mock_result.flagged = True
|
||||
mock_result.action = "NONE" # Some other action
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
234
api/tests/unit_tests/core/moderation/test_output_moderation.py
Normal file
234
api/tests/unit_tests/core/moderation/test_output_moderation.py
Normal file
@ -0,0 +1,234 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
|
||||
|
||||
class TestOutputModeration:
    """Unit tests for OutputModeration: buffering, moderation decisions, and worker-thread behavior."""

    @pytest.fixture
    def mock_queue_manager(self):
        # Queue manager stub used to observe published replace events.
        return MagicMock(spec=AppQueueManager)

    @pytest.fixture
    def moderation_rule(self):
        # Simple keyword rule; the exact config is echoed into ModerationFactory calls.
        return ModerationRule(type="keywords", config={"keywords": "badword"})

    @pytest.fixture
    def output_moderation(self, mock_queue_manager, moderation_rule):
        return OutputModeration(
            tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
        )

    def test_should_direct_output(self, output_moderation):
        """should_direct_output is True exactly when a final_output has been set."""
        assert output_moderation.should_direct_output() is False
        output_moderation.final_output = "blocked"
        assert output_moderation.should_direct_output() is True

    def test_get_final_output(self, output_moderation):
        """get_final_output returns the stored final output, defaulting to the empty string."""
        assert output_moderation.get_final_output() == ""
        output_moderation.final_output = "blocked"
        assert output_moderation.get_final_output() == "blocked"

    def test_append_new_token(self, output_moderation):
        """Appending tokens accumulates the buffer and starts the worker thread only once."""
        with patch.object(OutputModeration, "start_thread") as mock_start:
            output_moderation.append_new_token("hello")
            assert output_moderation.buffer == "hello"
            mock_start.assert_called_once()

            # With a thread already present, a second append must not spawn another.
            output_moderation.thread = MagicMock()
            output_moderation.append_new_token(" world")
            assert output_moderation.buffer == "hello world"
            assert mock_start.call_count == 1

    def test_moderation_completion_no_flag(self, output_moderation):
        """Unflagged completions pass through unchanged and mark the final chunk."""
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)

            output, flagged = output_moderation.moderation_completion("safe content")

            assert output == "safe content"
            assert flagged is False
            assert output_moderation.is_final_chunk is True

    def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
        """Flagged DIRECT_OUTPUT replaces the completion with the preset and publishes a replace event."""
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )

            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)

            assert output == "preset"
            assert flagged is True
            # public_event=True must publish a QueueMessageReplaceEvent from the task pipeline.
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert isinstance(args[0], QueueMessageReplaceEvent)
            assert args[0].text == "preset"
            assert args[1] == PublishFrom.TASK_PIPELINE

    def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
        """Flagged OVERRIDDEN substitutes the masked text and publishes a replace event."""
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
            )

            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)

            assert output == "masked content"
            assert flagged is True
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked content"

    def test_start_thread(self, output_moderation):
        """start_thread constructs a thread bound to the current Flask app and starts it."""
        mock_app = MagicMock(spec=Flask)
        with patch("core.moderation.output_moderation.current_app") as mock_current_app:
            mock_current_app._get_current_object.return_value = mock_app
            with patch("threading.Thread") as mock_thread_class:
                mock_thread_instance = MagicMock()
                mock_thread_class.return_value = mock_thread_instance

                thread = output_moderation.start_thread()

                assert thread == mock_thread_instance
                mock_thread_class.assert_called_once()
                mock_thread_instance.start.assert_called_once()

    def test_stop_thread(self, output_moderation):
        """stop_thread clears thread_running only while the thread is still alive."""
        mock_thread = MagicMock()
        mock_thread.is_alive.return_value = True
        output_moderation.thread = mock_thread

        output_moderation.stop_thread()
        assert output_moderation.thread_running is False

        # A dead thread is left alone: thread_running keeps its current value.
        output_moderation.thread_running = True
        mock_thread.is_alive.return_value = False
        output_moderation.stop_thread()
        assert output_moderation.thread_running is True

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_success(self, mock_factory_class, output_moderation):
        """moderation() builds a factory from the rule and returns the backend result."""
        mock_factory = mock_factory_class.return_value
        mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_outputs.return_value = mock_result

        result = output_moderation.moderation("tenant", "app", "buffer")

        assert result == mock_result
        # Factory must be instantiated with the rule's type and config.
        mock_factory_class.assert_called_once_with(
            name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
        )

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_exception(self, mock_factory_class, output_moderation):
        """moderation() swallows backend errors and returns None instead of raising."""
        mock_factory_class.side_effect = Exception("error")

        result = output_moderation.moderation("tenant", "app", "buffer")
        assert result is None

    def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
        """The worker returns promptly once thread_running is False (no hang)."""
        mock_app = MagicMock(spec=Flask)

        # Test exit on thread_running=False
        output_moderation.thread_running = False
        output_moderation.worker(mock_app, 10)
        # Should exit immediately

    def test_worker_no_flag(self, output_moderation):
        """The worker invokes moderation on a complete buffer even when nothing is flagged."""
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)

            output_moderation.buffer = "safe"
            output_moderation.is_final_chunk = True

            # To avoid infinite loop, we'll set thread_running to False after one iteration
            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False
                return mock_moderation.return_value

            mock_moderation.side_effect = side_effect

            output_moderation.worker(mock_app, 10)

            assert mock_moderation.called

    def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
        """A flagged DIRECT_OUTPUT in the worker records the preset and publishes once."""
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )

            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            assert output_moderation.final_output == "preset"
            mock_queue_manager.publish.assert_called_once()
            # It breaks on DIRECT_OUTPUT

    def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
        """A flagged OVERRIDDEN in the worker publishes the masked replacement text."""
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Use side_effect to change thread_running on second call
            def side_effect(*args, **kwargs):
                if mock_moderation.call_count > 1:
                    output_moderation.thread_running = False
                    return None
                return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")

            mock_moderation.side_effect = side_effect

            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked"

    def test_worker_chunk_too_small(self, output_moderation):
        """When the buffer is shorter than buffer_size and not final, the worker sleeps and retries."""
        mock_app = MagicMock(spec=Flask)
        with patch("time.sleep") as mock_sleep:
            # chunk_length < buffer_size and not is_final_chunk
            output_moderation.buffer = "123"  # length 3
            output_moderation.is_final_chunk = False

            def sleep_side_effect(seconds):
                output_moderation.thread_running = False

            mock_sleep.side_effect = sleep_side_effect

            output_moderation.worker(mock_app, 10)  # buffer_size 10

            mock_sleep.assert_called_once_with(1)

    def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
        """A None moderation result (error/no rule) must not publish anything."""
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Return None (exception or no rule)
            mock_moderation.return_value = None

            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False

            mock_moderation.side_effect = side_effect

            output_moderation.buffer = "something"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            mock_queue_manager.publish.assert_not_called()
|
||||
@ -0,0 +1,160 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
|
||||
TidbOnQdrantConfig,
|
||||
TidbOnQdrantVector,
|
||||
)
|
||||
|
||||
|
||||
class TestTidbOnQdrantVectorDeleteByIds:
    """Unit tests for TidbOnQdrantVector.delete_by_ids method."""

    @pytest.fixture
    def vector_instance(self):
        """Create a TidbOnQdrantVector instance for testing."""
        config = TidbOnQdrantConfig(
            endpoint="http://localhost:6333",
            api_key="test_api_key",
        )

        # Patch the underlying Qdrant client so no network connection is attempted.
        with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
            vector = TidbOnQdrantVector(
                collection_name="test_collection",
                group_id="test_group",
                config=config,
            )
        return vector

    def test_delete_by_ids_with_multiple_ids(self, vector_instance):
        """Test batch deletion with multiple document IDs."""
        ids = ["doc1", "doc2", "doc3"]

        vector_instance.delete_by_ids(ids)

        # Verify that delete was called once with MatchAny filter
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        # Check collection name
        assert call_args[1]["collection_name"] == "test_collection"

        # Verify filter uses MatchAny with all IDs
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        assert len(filter_obj.must) == 1

        field_condition = filter_obj.must[0]
        assert field_condition.key == "metadata.doc_id"
        assert isinstance(field_condition.match, rest.MatchAny)
        # Order is not part of the contract; compare as sets.
        assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}

    def test_delete_by_ids_with_single_id(self, vector_instance):
        """Test deletion with a single document ID."""
        ids = ["doc1"]

        vector_instance.delete_by_ids(ids)

        # Verify that delete was called once
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        # Verify filter uses MatchAny with single ID
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ["doc1"]

    def test_delete_by_ids_with_empty_list(self, vector_instance):
        """Test deletion with empty ID list returns early without API call."""
        vector_instance.delete_by_ids([])

        # Verify that delete was NOT called
        vector_instance._client.delete.assert_not_called()

    def test_delete_by_ids_with_404_error(self, vector_instance):
        """Test that 404 errors (collection not found) are handled gracefully."""
        ids = ["doc1", "doc2"]

        # Mock a 404 error
        error = UnexpectedResponse(
            status_code=404,
            reason_phrase="Not Found",
            content=b"Collection not found",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error

        # Should not raise an exception
        vector_instance.delete_by_ids(ids)

        # Verify delete was called
        vector_instance._client.delete.assert_called_once()

    def test_delete_by_ids_with_unexpected_error(self, vector_instance):
        """Test that non-404 errors are re-raised."""
        ids = ["doc1", "doc2"]

        # Mock a 500 error
        error = UnexpectedResponse(
            status_code=500,
            reason_phrase="Internal Server Error",
            content=b"Server error",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error

        # Should re-raise the exception
        with pytest.raises(UnexpectedResponse) as exc_info:
            vector_instance.delete_by_ids(ids)

        assert exc_info.value.status_code == 500

    def test_delete_by_ids_with_large_batch(self, vector_instance):
        """Test deletion with a large batch of IDs."""
        # Create 1000 IDs
        ids = [f"doc_{i}" for i in range(1000)]

        vector_instance.delete_by_ids(ids)

        # Verify single delete call with all IDs
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]

        # Verify all 1000 IDs are in the batch
        assert len(field_condition.match.any) == 1000
        assert "doc_0" in field_condition.match.any
        assert "doc_999" in field_condition.match.any

    def test_delete_by_ids_filter_structure(self, vector_instance):
        """Test that the filter structure is correctly constructed."""
        ids = ["doc1", "doc2"]

        vector_instance.delete_by_ids(ids)

        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter

        # Verify Filter structure
        assert isinstance(filter_obj, rest.Filter)
        assert filter_obj.must is not None
        assert len(filter_obj.must) == 1

        # Verify FieldCondition structure
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition, rest.FieldCondition)
        assert field_condition.key == "metadata.doc_id"

        # Verify MatchAny structure
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ids
|
||||
@ -0,0 +1,33 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
|
||||
|
||||
def test_init_client_with_valid_config():
    """Client initialization succeeds and connection parameters are derived from the endpoint URL."""
    config = WeaviateConfig(
        endpoint="http://localhost:8080",
        api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
    )

    with patch("weaviate.connect_to_custom") as mock_connect:
        ready_client = MagicMock()
        ready_client.is_ready.return_value = True
        mock_connect.return_value = ready_client

        vector = WeaviateVector(
            collection_name="test_collection",
            config=config,
            attributes=["doc_id"],
        )

        # The vector store must hold exactly the client returned by connect_to_custom.
        assert vector._client is ready_client
        mock_connect.assert_called_once()

        # HTTP host/port come from the endpoint URL; gRPC falls back to defaults.
        connect_kwargs = mock_connect.call_args[1]
        expected_params = {
            "http_host": "localhost",
            "http_port": 8080,
            "http_secure": False,
            "grpc_host": "localhost",
            "grpc_port": 50051,
            "grpc_secure": False,
        }
        for key, value in expected_params.items():
            assert connect_kwargs[key] == value
        # API key must be translated into non-None auth credentials.
        assert connect_kwargs["auth_credentials"] is not None
|
||||
@ -0,0 +1,335 @@
|
||||
"""Unit tests for Weaviate vector database implementation.
|
||||
|
||||
Focuses on verifying that doc_type is properly handled in:
|
||||
- Collection schema creation (_create_collection)
|
||||
- Property migration (_ensure_properties)
|
||||
- Vector search result metadata (search_by_vector)
|
||||
- Full-text search result metadata (search_by_full_text)
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class TestWeaviateVector(unittest.TestCase):
    """Tests for WeaviateVector class with focus on doc_type metadata handling."""

    def setUp(self):
        # Reset the module-level cached client so each test builds a fresh mocked one.
        weaviate_vector_module._weaviate_client = None
        self.config = WeaviateConfig(
            endpoint="http://localhost:8080",
            api_key="test-key",
            batch_size=100,
        )
        self.collection_name = "Test_Collection_Node"
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]

    def tearDown(self):
        # Clear the cached client again so state does not leak into other test modules.
        weaviate_vector_module._weaviate_client = None

    # NOTE(review): this helper is decorated with @patch yet no visible test calls it;
    # consider removing it or converting it to a plain context-managed helper.
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def _create_weaviate_vector(self, mock_weaviate_module):
        """Helper to create a WeaviateVector instance with mocked client."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        return wv, mock_client

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_init(self, mock_weaviate_module):
        """Test WeaviateVector initialization stores attributes including doc_type."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )

        assert wv._collection_name == self.collection_name
        assert "doc_type" in wv._attributes

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config")
    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_create_collection_includes_doc_type_property(self, mock_weaviate_module, mock_dify_config, mock_redis):
        """Test that _create_collection defines doc_type in the schema properties."""
        # Mock Redis
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock dify_config
        mock_dify_config.WEAVIATE_TOKENIZATION = None

        # Mock client
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client
        mock_client.collections.exists.return_value = False

        # Mock _ensure_properties to avoid side effects
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col
        mock_cfg = MagicMock()
        mock_cfg.properties = []
        mock_col.config.get.return_value = mock_cfg

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._create_collection()

        # Verify collections.create was called
        mock_client.collections.create.assert_called_once()

        # Extract properties from the create call
        call_kwargs = mock_client.collections.create.call_args
        properties = call_kwargs.kwargs.get("properties")

        # Verify doc_type is among the defined properties
        property_names = [p.name for p in properties]
        assert "doc_type" in property_names, (
            f"doc_type should be in collection schema properties, got: {property_names}"
        )

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module):
        """Test that _ensure_properties adds doc_type when it's missing from existing schema."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        # Collection exists but doc_type property is missing
        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col

        # Simulate existing properties WITHOUT doc_type
        existing_props = [
            SimpleNamespace(name="text"),
            SimpleNamespace(name="document_id"),
            SimpleNamespace(name="doc_id"),
            SimpleNamespace(name="chunk_index"),
        ]
        mock_cfg = MagicMock()
        mock_cfg.properties = existing_props
        mock_col.config.get.return_value = mock_cfg

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._ensure_properties()

        # Verify add_property was called and includes doc_type
        add_calls = mock_col.config.add_property.call_args_list
        added_names = [call.args[0].name for call in add_calls]
        assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}"

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
        """Test that _ensure_properties does not add doc_type when it already exists."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col

        # Simulate existing properties WITH doc_type already present
        existing_props = [
            SimpleNamespace(name="text"),
            SimpleNamespace(name="document_id"),
            SimpleNamespace(name="doc_id"),
            SimpleNamespace(name="doc_type"),
            SimpleNamespace(name="chunk_index"),
        ]
        mock_cfg = MagicMock()
        mock_cfg.properties = existing_props
        mock_col.config.get.return_value = mock_cfg

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        wv._ensure_properties()

        # No properties should be added
        mock_col.config.add_property.assert_not_called()

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
        """Test that search_by_vector returns doc_type in document metadata.

        This is the core bug fix verification: when doc_type is in _attributes,
        it should appear in return_properties and thus be included in results.
        """
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col

        # Simulate search result with doc_type in properties
        mock_obj = MagicMock()
        mock_obj.properties = {
            "text": "image content description",
            "doc_id": "upload_file_id_123",
            "dataset_id": "dataset_1",
            "document_id": "doc_1",
            "doc_hash": "hash_abc",
            "doc_type": "image",
        }
        mock_obj.metadata.distance = 0.1

        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.near_vector.return_value = mock_result

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        docs = wv.search_by_vector(query_vector=[0.1] * 128, top_k=1)

        # Verify doc_type is in return_properties
        call_kwargs = mock_col.query.near_vector.call_args
        return_props = call_kwargs.kwargs.get("return_properties")
        assert "doc_type" in return_props, f"doc_type should be in return_properties, got: {return_props}"

        # Verify doc_type is in result metadata
        assert len(docs) == 1
        assert docs[0].metadata.get("doc_type") == "image"

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
        """Test that search_by_full_text also returns doc_type in document metadata."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        mock_client.collections.exists.return_value = True
        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col

        # Simulate BM25 search result with doc_type
        mock_obj = MagicMock()
        mock_obj.properties = {
            "text": "image content description",
            "doc_id": "upload_file_id_456",
            "doc_type": "image",
        }
        mock_obj.vector = {"default": [0.1] * 128}

        mock_result = MagicMock()
        mock_result.objects = [mock_obj]
        mock_col.query.bm25.return_value = mock_result

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )
        docs = wv.search_by_full_text(query="image", top_k=1)

        # Verify doc_type is in return_properties
        call_kwargs = mock_col.query.bm25.call_args
        return_props = call_kwargs.kwargs.get("return_properties")
        assert "doc_type" in return_props, (
            f"doc_type should be in return_properties for BM25 search, got: {return_props}"
        )

        # Verify doc_type is in result metadata
        assert len(docs) == 1
        assert docs[0].metadata.get("doc_type") == "image"

    @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
    def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
        """Test that add_texts includes doc_type from document metadata in stored properties."""
        mock_client = MagicMock()
        mock_client.is_ready.return_value = True
        mock_weaviate_module.connect_to_custom.return_value = mock_client

        mock_col = MagicMock()
        mock_client.collections.use.return_value = mock_col

        # Create a document with doc_type metadata (as produced by multimodal indexing)
        doc = Document(
            page_content="an image of a cat",
            metadata={
                "doc_id": "upload_file_123",
                "doc_type": "image",
                "dataset_id": "ds_1",
                "document_id": "doc_1",
                "doc_hash": "hash_xyz",
            },
        )

        wv = WeaviateVector(
            collection_name=self.collection_name,
            config=self.config,
            attributes=self.attributes,
        )

        # Mock batch context manager
        mock_batch = MagicMock()
        mock_batch.__enter__ = MagicMock(return_value=mock_batch)
        mock_batch.__exit__ = MagicMock(return_value=False)
        mock_col.batch.dynamic.return_value = mock_batch

        wv.add_texts(documents=[doc], embeddings=[[0.1] * 128])

        # Verify batch.add_object was called with doc_type in properties
        mock_batch.add_object.assert_called_once()
        call_kwargs = mock_batch.add_object.call_args
        stored_props = call_kwargs.kwargs.get("properties")
        assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}"
||||
|
||||
|
||||
class TestVectorDefaultAttributes(unittest.TestCase):
    """Tests for Vector class default attributes list."""

    @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
    @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
    def test_default_attributes_include_doc_type(self, mock_init_vector, mock_get_embeddings):
        """The Vector factory must request ``doc_type`` by default."""
        from core.rag.datasource.vdb.vector_factory import Vector

        # Stub out embedding/vector initialisation so no real backend is touched.
        mock_get_embeddings.return_value = MagicMock()
        mock_init_vector.return_value = MagicMock()

        dataset_stub = MagicMock()
        dataset_stub.index_struct_dict = None

        vector = Vector(dataset=dataset_stub)

        assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -104,10 +104,11 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_map_known_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error"))
|
||||
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
|
||||
|
||||
assert app.map("https://example.com") == {}
|
||||
with pytest.raises(Exception, match="map error"):
|
||||
app.map("https://example.com")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_map_unknown_error_raises(self, mocker: MockerFixture):
|
||||
@ -177,10 +178,11 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error"))
|
||||
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
|
||||
|
||||
assert app.check_crawl_status("job-1") == {}
|
||||
with pytest.raises(Exception, match="crawl error"):
|
||||
app.check_crawl_status("job-1")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
|
||||
@ -272,9 +274,10 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_search_known_http_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error"))
|
||||
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
|
||||
assert app.search("python") == {}
|
||||
with pytest.raises(Exception, match="search error"):
|
||||
app.search("python")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_search_unknown_http_error(self, mocker: MockerFixture):
|
||||
|
||||
@ -236,7 +236,8 @@ class TestParagraphIndexProcessor:
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [accepted, rejected]
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == 0.9
|
||||
|
||||
@ -307,7 +307,8 @@ class TestParentChildIndexProcessor:
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [ok_result, low_result]
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {})
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "keep"
|
||||
|
||||
@ -262,7 +262,8 @@ class TestQAIndexProcessor:
|
||||
|
||||
with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
|
||||
mock_retrieve.return_value = [result_ok, result_low]
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "accepted"
|
||||
|
||||
@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.rag.data_post_processor.data_post_processor import WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
extra={"dataset_name": "Ext", "title": "Ext"},
|
||||
)
|
||||
app = Flask(__name__)
|
||||
weights = {"vector_setting": {}}
|
||||
weights: WeightsDict = {
|
||||
"vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""},
|
||||
"keyword_setting": {"keyword_weight": 0.5},
|
||||
}
|
||||
|
||||
def fake_multiple_thread(**kwargs):
|
||||
if kwargs["query"]:
|
||||
|
||||
@ -0,0 +1,677 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_HumanInputFormEntityImpl,
|
||||
_HumanInputFormRecipientEntityImpl,
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace SQLAlchemy's select/selectinload with inert stand-ins.

    The fake select object returns itself from every chained builder method
    the repository uses, so no real SQL statement is ever constructed.
    """

    class _FakeSelect:
        def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
    monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
|
||||
|
||||
|
||||
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
    """Build a minimal serialized form definition for the dummy form models."""
    definition: dict[str, Any] = {
        "form_content": "hi",
        "inputs": [],
        "user_actions": [{"id": "submit", "title": "Submit"}],
        "rendered_content": "<p>hi</p>",
    }
    if include_expiration_time:
        # datetime is not JSON-serialisable on its own, hence default=str below.
        definition["expiration_time"] = naive_utc_now()
    return json.dumps(definition, default=str)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _DummyForm:
    """Plain dataclass stand-in for the HumanInputForm ORM model."""

    id: str
    workflow_run_id: str | None
    node_id: str
    tenant_id: str
    app_id: str
    form_definition: str  # JSON blob, see _make_form_definition_json
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    # Submission-related fields, unset until the form is acted upon.
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _DummyRecipient:
    """Plain dataclass stand-in for the HumanInputFormRecipient ORM model."""

    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str | None
|
||||
|
||||
|
||||
class _FakeScalarResult:
    """Mimics SQLAlchemy's scalar result: wraps a single object or a list."""

    def __init__(self, obj: Any):
        self._value = obj

    def first(self) -> Any:
        # A list yields its head (or None when empty); anything else —
        # including None — is returned as-is.
        if isinstance(self._value, list):
            return self._value[0] if self._value else None
        return self._value

    def all(self) -> list[Any]:
        # None -> empty, list -> shallow copy, scalar -> singleton list.
        if self._value is None:
            return []
        if isinstance(self._value, list):
            return list(self._value)
        return [self._value]
|
||||
|
||||
|
||||
class _FakeExecuteResult:
    """Mimics a SQLAlchemy execute() result holding plain row tuples."""

    def __init__(self, rows: Sequence[tuple[Any, ...]]):
        self._rows = list(rows)

    def all(self) -> list[tuple[Any, ...]]:
        # Hand back a copy so callers cannot mutate our state.
        return list(self._rows)
|
||||
|
||||
|
||||
class _FakeSession:
    """In-memory stand-in for a SQLAlchemy session.

    Scalars are served from a queue (one entry per ``scalars`` call), ``get``
    answers from the supplied form/recipient dicts, and ``flush`` simulates
    the DB defaults the repository relies on (generated ids, access tokens).
    """

    def __init__(
        self,
        *,
        scalars_result: Any = None,
        scalars_results: list[Any] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
        execute_rows: Sequence[tuple[Any, ...]] = (),
    ):
        if scalars_results is not None:
            self._scalars_queue = list(scalars_results)
        else:
            self._scalars_queue = [scalars_result]
        self._forms = forms or {}
        self._recipients = recipients or {}
        self._execute_rows = list(execute_rows)
        self.added: list[Any] = []

    def scalars(self, _query: Any) -> _FakeScalarResult:
        # Pop the next queued value; an exhausted queue yields None.
        value = self._scalars_queue.pop(0) if self._scalars_queue else None
        return _FakeScalarResult(value)

    def execute(self, _stmt: Any) -> _FakeExecuteResult:
        return _FakeExecuteResult(self._execute_rows)

    def get(self, model_cls: Any, obj_id: str) -> Any:
        # Dispatch on the class name so the dummy dataclasses work as models.
        name = getattr(model_cls, "__name__", "")
        if name == "HumanInputForm":
            return self._forms.get(obj_id)
        if name == "HumanInputFormRecipient":
            return self._recipients.get(obj_id)
        return None

    def add(self, obj: Any) -> None:
        self.added.append(obj)

    def add_all(self, objs: Sequence[Any]) -> None:
        self.added.extend(list(objs))

    def flush(self) -> None:
        # Simulate DB default population for attributes referenced in entity wrappers.
        for obj in self.added:
            if hasattr(obj, "id") and obj.id in (None, ""):
                # NOTE(review): len(str(self.added)) looks like it was meant to
                # be len(self.added); preserved as-is since callers only need a
                # non-empty id — confirm before relying on the value.
                obj.id = f"gen-{len(str(self.added))}"
            if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
                if obj.recipient_type == RecipientType.CONSOLE:
                    obj.access_token = "token-console"
                elif obj.recipient_type == RecipientType.BACKSTAGE:
                    obj.access_token = "token-backstage"
                else:
                    obj.access_token = "token-webapp"

    def refresh(self, _obj: Any) -> None:
        return None

    def begin(self) -> _FakeSession:
        # The session doubles as its own transaction context manager.
        return self

    def __enter__(self) -> _FakeSession:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None
|
||||
|
||||
|
||||
class _SessionFactoryStub:
    """session_factory replacement that always hands out one fixed session."""

    def __init__(self, session: _FakeSession):
        self._session = session

    def create_session(self) -> _FakeSession:
        return self._session
|
||||
|
||||
|
||||
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
    """Point the repository module's session_factory at the given fake session."""
    stub = _SessionFactoryStub(session)
    monkeypatch.setattr("core.repositories.human_input_repository.session_factory", stub)
|
||||
|
||||
|
||||
def test_recipient_entity_token_raises_when_missing() -> None:
    """Accessing .token on a recipient without an access_token must assert."""
    bare_recipient = SimpleNamespace(id="r1", access_token=None)
    wrapper = _HumanInputFormRecipientEntityImpl(bare_recipient)  # type: ignore[arg-type]
    with pytest.raises(AssertionError, match="access_token should not be None"):
        _ = wrapper.token
|
||||
|
||||
|
||||
def test_recipient_entity_id_and_token_success() -> None:
    """id and token pass straight through from the underlying model."""
    model = SimpleNamespace(id="r1", access_token="tok")
    wrapper = _HumanInputFormRecipientEntityImpl(model)  # type: ignore[arg-type]
    assert wrapper.id == "r1"
    assert wrapper.token == "tok"
|
||||
|
||||
|
||||
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
    """web_app_token priority: console recipient, then web-app recipient, else None."""
    form_model = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    console = _DummyRecipient(id="c1", form_id=form_model.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
    webapp = _DummyRecipient(
        id="w1", form_id=form_model.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
    )

    # Console wins even when listed after the web-app recipient.
    entity = _HumanInputFormEntityImpl(form_model=form_model, recipient_models=[webapp, console])  # type: ignore[arg-type]
    assert entity.web_app_token == "ctok"

    # Without a console recipient the web-app token is used.
    entity = _HumanInputFormEntityImpl(form_model=form_model, recipient_models=[webapp])  # type: ignore[arg-type]
    assert entity.web_app_token == "wtok"

    # No recipients at all -> no token.
    entity = _HumanInputFormEntityImpl(form_model=form_model, recipient_models=[])  # type: ignore[arg-type]
    assert entity.web_app_token is None
|
||||
|
||||
|
||||
def test_form_entity_submitted_data_parsed() -> None:
    """submitted_data JSON is decoded and submission state is reflected."""
    form_model = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        submitted_data='{"a": 1}',
        submitted_at=naive_utc_now(),
    )
    entity = _HumanInputFormEntityImpl(form_model=form_model, recipient_models=[])  # type: ignore[arg-type]
    assert entity.submitted is True
    assert entity.submitted_data == {"a": 1}
    assert entity.rendered_content == "<p>x</p>"
    assert entity.selected_action_id is None
    assert entity.status == HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
    """When the stored definition lacks expiration_time, the model's value is injected."""
    deadline = naive_utc_now()
    form_model = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=False),
        rendered_content="<p>x</p>",
        expiration_time=deadline,
        submitted_data='{"k": "v"}',
    )
    record = HumanInputFormRecord.from_models(form_model, None)  # type: ignore[arg-type]
    assert record.definition.expiration_time == deadline
    assert record.submitted_data == {"k": "v"}
    # No submitted_at -> not considered submitted, even with submitted_data set.
    assert record.submitted is False
|
||||
|
||||
|
||||
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
    """Blank emails are dropped and duplicates collapse to one recipient per kind."""
    created: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload: Any):  # type: ignore[no-untyped-def]
        stub = SimpleNamespace(
            id=f"{payload.TYPE}-{len(created)}",
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
            access_token="tok",
        )
        created.append(stub)
        return stub

    monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))

    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    recipients = repo._create_email_recipients_from_resolved(  # type: ignore[attr-defined]
        form_id="f",
        delivery_id="d",
        members=[
            _WorkspaceMemberInfo(user_id="u1", email=""),  # blank -> skipped
            _WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
            _WorkspaceMemberInfo(user_id="u3", email="a@example.com"),  # duplicate
        ],
        external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
    )
    assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
    """All-blank user ids short-circuit to an empty member list."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_maps_rows() -> None:
    """Result rows are mapped one-to-one onto _WorkspaceMemberInfo."""
    fake_session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repo._query_workspace_members_by_ids(session=fake_session, restrict_to_user_ids=["u1", "u2"])
    assert members == [
        _WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
        _WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
    ]
|
||||
|
||||
|
||||
def test_query_all_workspace_members_maps_rows() -> None:
    """Whole-workspace member query maps rows onto _WorkspaceMemberInfo."""
    fake_session = _FakeSession(execute_rows=[("u1", "a@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repo._query_all_workspace_members(session=fake_session)
    assert members == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
|
||||
|
||||
|
||||
def test_repository_init_sets_tenant_id() -> None:
    """The constructor stores the tenant id verbatim."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._tenant_id == "tenant"
|
||||
|
||||
|
||||
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
    """A web-app delivery method yields one delivery and one standalone recipient."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    # Pin the generated delivery id for deterministic assertions.
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    result = repo._delivery_method_to_model(
        session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
    )
    assert result.delivery.id == "del-1"
    assert result.delivery.form_id == "form-1"
    assert len(result.recipients) == 1
    assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
|
||||
|
||||
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """An email delivery method delegates recipient creation to _build_email_recipients."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    seen_kwargs: dict[str, Any] = {}

    def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
        seen_kwargs.update(
            {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
        )
        return ["r"]

    monkeypatch.setattr(repo, "_build_email_recipients", fake_build)

    email_method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
            ),
            subject="s",
            body="b",
        )
    )
    result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=email_method)
    assert result.recipients == ["r"]
    assert seen_kwargs["delivery_id"] == "del-1"
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=True resolves recipients from the full member list."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr(
        repo,
        "_query_all_workspace_members",
        lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
    )
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    built = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
    )
    assert built == ["ok"]
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=False restricts the member lookup to the selected user ids."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
        # Only the explicitly selected member id may be queried.
        assert restrict_to_user_ids == ["u1"]
        return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]

    monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    built = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(
            whole_workspace=False,
            items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
        ),
    )
    assert built == ["ok"]
|
||||
|
||||
|
||||
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_form returns None for an unknown run/node and an entity otherwise."""
    # Missing form: the first scalars() call yields nothing.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo.get_form("run", "node") is None

    form_model = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    webapp_recipient = _DummyRecipient(
        id="r1",
        form_id=form_model.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
    )
    # First scalars() call returns the form, the second its recipients.
    fake_session = _FakeSession(scalars_results=[form_model, [webapp_recipient]])
    _patch_session_factory(monkeypatch, fake_session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    entity = repo.get_form("run", "node")
    assert entity is not None
    assert entity.id == "f1"
    assert entity.recipients[0].id == "r1"
    assert entity.recipients[0].token == "tok"
|
||||
|
||||
|
||||
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_form adds console/backstage recipients on top of delivery methods."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)

    # Deterministic ids: the form first, then one per delivery.
    generated_ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(generated_ids))

    fake_session = _FakeSession()
    _patch_session_factory(monkeypatch, fake_session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    form_config = HumanInputNodeData(
        title="Title",
        delivery_methods=[],
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
    )
    params = FormCreateParams(
        app_id="app",
        workflow_execution_id="run",
        node_id="node",
        form_config=form_config,
        rendered_content="<p>hello</p>",
        delivery_methods=[WebAppDeliveryMethod()],
        display_in_ui=True,
        resolved_default_values={},
        form_kind=HumanInputFormKind.RUNTIME,
        console_recipient_required=True,
        console_creator_account_id="acc-1",
        backstage_recipient_required=True,
    )

    entity = repo.create_form(params)
    assert entity.id == "form-id"
    assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
    # Console token should take precedence when console recipient is present.
    assert entity.web_app_token == "token-console"
    assert len(entity.recipients) == 3
|
||||
|
||||
|
||||
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_token is None when the recipient or its form cannot be found."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None

    # Recipient exists but carries no form.
    orphan_recipient = SimpleNamespace(form=None)
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=orphan_recipient))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None
|
||||
|
||||
|
||||
def test_submission_repository_init_no_args() -> None:
    """The submission repository can be constructed without arguments."""
    repo = HumanInputFormSubmissionRepository()
    assert isinstance(repo, HumanInputFormSubmissionRepository)
|
||||
|
||||
|
||||
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both lookup entry points produce a record for a resolvable recipient."""
    form_model = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient_model = SimpleNamespace(
        id="r1",
        form_id=form_model.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
        form=form_model,
    )

    # Lookup by access token.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient_model))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_token("tok")
    assert record is not None
    assert record.access_token == "tok"

    # Lookup by form id and recipient type.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient_model))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_form_id_and_recipient_type(
        form_id=form_model.id, recipient_type=RecipientType.STANDALONE_WEB_APP
    )
    assert record is not None
    assert record.recipient_id == "r1"
|
||||
|
||||
|
||||
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unknown form id yields None from the typed lookup."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
|
||||
|
||||
|
||||
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_submitted raises for unknown forms and records the submission otherwise."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)

    # Unknown form id -> FormNotFoundError.
    _patch_session_factory(monkeypatch, _FakeSession(forms={}))
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_submitted(
            form_id="missing",
            recipient_id=None,
            selected_action_id="a",
            form_data={},
            submission_user_id=None,
            submission_end_user_id=None,
        )

    form_model = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=fixed_now,
    )
    recipient_model = _DummyRecipient(
        id="r", form_id=form_model.id, recipient_type=RecipientType.CONSOLE, access_token="tok"
    )
    _patch_session_factory(
        monkeypatch, _FakeSession(forms={form_model.id: form_model}, recipients={recipient_model.id: recipient_model})
    )
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_submitted(
        form_id=form_model.id,
        recipient_id=recipient_model.id,
        selected_action_id="approve",
        form_data={"k": "v"},
        submission_user_id="u",
        submission_end_user_id="eu",
    )
    assert form_model.status == HumanInputFormStatus.SUBMITTED
    assert form_model.submitted_at == fixed_now
    assert record.submitted_data == {"k": "v"}
|
||||
|
||||
|
||||
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Passing a non-timeout status to mark_timeout is rejected."""
    form_model = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={form_model.id: form_model}))
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
        repo.mark_timeout(form_id=form_model.id, timeout_status=HumanInputFormStatus.SUBMITTED)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.TIMEOUT,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
|
||||
assert record.status == HumanInputFormStatus.TIMEOUT
|
||||
|
||||
|
||||
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form already submitted"):
|
||||
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
|
||||
|
||||
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="a",
|
||||
submitted_data="{}",
|
||||
submission_user_id="u",
|
||||
submission_end_user_id="eu",
|
||||
completed_by_recipient_id="r",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
assert form.status == HumanInputFormStatus.EXPIRED
|
||||
assert form.selected_action_id is None
|
||||
assert form.submitted_data is None
|
||||
assert form.submission_user_id is None
|
||||
assert form.submission_end_user_id is None
|
||||
assert form.completed_by_recipient_id is None
|
||||
assert record.status == HumanInputFormStatus.EXPIRED
|
||||
|
||||
|
||||
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(forms={}))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form not found"):
|
||||
repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)
|
||||
@ -1,84 +1,291 @@
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from models import Account, WorkflowRun
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from models import Account, CreatorUserRole, EndUser, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = str(uuid4())
|
||||
user.current_tenant_id = str(uuid4())
|
||||
|
||||
repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=real_session_factory,
|
||||
user=user,
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = False
|
||||
repository._session_factory = MagicMock(return_value=session_context)
|
||||
return repository
|
||||
|
||||
|
||||
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
|
||||
return WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "hello"},
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist():
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
session_factory.return_value.__enter__.return_value = session
|
||||
return session_factory
|
||||
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists():
|
||||
session = MagicMock()
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Mock SQLAlchemy Engine."""
|
||||
return MagicMock(spec=Engine)
|
||||
|
||||
execution_id = str(uuid4())
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repository._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
session.get.return_value = existing_run
|
||||
|
||||
execution = _build_execution(
|
||||
execution_id=execution_id,
|
||||
started_at=datetime(2026, 1, 1, 12, 30, 0),
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_execution():
|
||||
"""Sample WorkflowExecution for testing."""
|
||||
return WorkflowExecution(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
outputs={"output1": "result1"},
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
error_message="",
|
||||
total_tokens=100,
|
||||
total_steps=5,
|
||||
exceptions_count=0,
|
||||
started_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
class TestSQLAlchemyWorkflowExecutionRepository:
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
app_id = "test_app_id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
assert repo._session_factory == mock_session_factory
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_init_with_engine(self, mock_engine, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_engine,
|
||||
user=mock_account,
|
||||
app_id="test_app_id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
assert repo._session_factory.kw["bind"] == mock_engine
|
||||
|
||||
def test_init_invalid_session_factory(self, mock_account):
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_no_tenant_id(self, mock_session_factory):
|
||||
user = MagicMock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
|
||||
)
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
assert repo._creator_user_role == CreatorUserRole.END_USER
|
||||
|
||||
def test_to_domain_model(self, mock_session_factory, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
db_model = MagicMock(spec=WorkflowRun)
|
||||
db_model.id = str(uuid4())
|
||||
db_model.workflow_id = str(uuid4())
|
||||
db_model.type = "workflow"
|
||||
db_model.version = "1.0"
|
||||
db_model.inputs_dict = {"in": "val"}
|
||||
db_model.outputs_dict = {"out": "val"}
|
||||
db_model.graph_dict = {"nodes": []}
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = "some error"
|
||||
db_model.total_tokens = 50
|
||||
db_model.total_steps = 3
|
||||
db_model.exceptions_count = 1
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = datetime.now(UTC)
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
assert domain_model.id_ == db_model.id
|
||||
assert domain_model.workflow_id == db_model.workflow_id
|
||||
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert domain_model.inputs == db_model.inputs_dict
|
||||
assert domain_model.error_message == "some error"
|
||||
|
||||
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
|
||||
# Make elapsed time deterministic to avoid flaky tests
|
||||
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.id == sample_workflow_execution.id_
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
assert db_model.app_id == "test_app"
|
||||
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
assert db_model.status == sample_workflow_execution.status.value
|
||||
assert db_model.total_tokens == sample_workflow_execution.total_tokens
|
||||
assert db_model.elapsed_time == 10.0
|
||||
|
||||
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Test with empty/None fields
|
||||
sample_workflow_execution.graph = None
|
||||
sample_workflow_execution.inputs = None
|
||||
sample_workflow_execution.outputs = None
|
||||
sample_workflow_execution.error_message = None
|
||||
sample_workflow_execution.finished_at = None
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.graph is None
|
||||
assert db_model.inputs is None
|
||||
assert db_model.outputs is None
|
||||
assert db_model.error is None
|
||||
assert db_model.elapsed_time == 0
|
||||
|
||||
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
assert not hasattr(db_model, "app_id") or db_model.app_id is None
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
|
||||
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
# Test triggered_from missing
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._creator_user_id = "some_id"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.merge.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
# Check cache
|
||||
assert sample_workflow_execution.id_ in repo._execution_cache
|
||||
cached_model = repo._execution_cache[sample_workflow_execution.id_]
|
||||
assert cached_model.id == sample_workflow_execution.id_
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.started_at = started_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = None
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
execution_id = sample_workflow_execution.id_
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repo._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = existing_run
|
||||
|
||||
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@ -0,0 +1,772 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
_deterministic_json_dump,
|
||||
_filter_by_offload_type,
|
||||
_find_first,
|
||||
_replace_or_append_offload,
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models import Account, EndUser
|
||||
from models.enums import ExecutionOffLoadType
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = user_id
|
||||
user.tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _execution(
|
||||
*,
|
||||
execution_id: str = "exec-id",
|
||||
node_execution_id: str = "node-exec-id",
|
||||
workflow_run_id: str = "run-id",
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
node_execution_id=node_execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Title",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
status=status,
|
||||
error=None,
|
||||
elapsed_time=1.0,
|
||||
metadata=metadata,
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
|
||||
class _SessionCtx:
|
||||
def __init__(self, session: Any):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: Any) -> sessionmaker:
|
||||
factory = Mock(spec=sessionmaker)
|
||||
factory.return_value = _SessionCtx(session)
|
||||
return factory
|
||||
|
||||
|
||||
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
engine: Engine = create_engine("sqlite:///:memory:")
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=engine,
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
|
||||
sm = Mock(spec=sessionmaker)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=sm,
|
||||
user=_mock_end_user(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert repo._creator_user_role.value == "end_user"
|
||||
|
||||
|
||||
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type]
|
||||
session_factory=object(),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
user = _mock_account()
|
||||
user.current_tenant_id = None
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=user,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created: dict[str, Any] = {}
|
||||
|
||||
class FakeTruncator:
|
||||
def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
|
||||
created.update(
|
||||
{
|
||||
"max_size_bytes": max_size_bytes,
|
||||
"array_element_limit": array_element_limit,
|
||||
"string_length_limit": string_length_limit,
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
|
||||
FakeTruncator,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
_ = repo._create_truncator()
|
||||
assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
|
||||
|
||||
|
||||
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
|
||||
assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
|
||||
assert _find_first([], lambda _: True) is None
|
||||
assert _find_first([1, 2, 3], lambda x: x > 1) == 2
|
||||
|
||||
off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
|
||||
|
||||
replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
|
||||
assert len(replaced) == 2
|
||||
assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
|
||||
|
||||
|
||||
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
|
||||
|
||||
# Happy path: deterministic json dump should be sorted
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
|
||||
assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution()
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert db_model.app_id == "app"
|
||||
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
repo._creator_user_id = "user"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_is_duplicate_key_error_and_regenerate_id(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
assert repo._is_duplicate_key_error(duplicate_error) is True
|
||||
assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
|
||||
|
||||
execution = _execution(execution_id="old-id")
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "old-id"
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
caplog.set_level(logging.WARNING)
|
||||
repo._regenerate_id_on_duplicate(execution, db_model)
|
||||
assert execution.id == "new-id"
|
||||
assert db_model.id == "new-id"
|
||||
assert any("Duplicate key conflict" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id1"
|
||||
db_model.node_execution_id = "node1"
|
||||
db_model.foo = "bar" # type: ignore[attr-defined]
|
||||
db_model.__dict__["_private"] = "x"
|
||||
|
||||
existing = SimpleNamespace()
|
||||
session.get.return_value = existing
|
||||
repo._persist_to_database(db_model)
|
||||
assert existing.foo == "bar"
|
||||
session.add.assert_not_called()
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
session.reset_mock()
|
||||
session.get.return_value = None
|
||||
repo._node_execution_cache.clear()
|
||||
repo._persist_to_database(db_model)
|
||||
session.add.assert_called_once_with(db_model)
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
|
||||
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return value, False
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
|
||||
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
uploaded: dict[str, Any] = {}
|
||||
|
||||
class FakeFileService:
|
||||
def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def]
|
||||
uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
|
||||
return SimpleNamespace(id="file-id", key="file-key")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return {"truncated": True}, True
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
|
||||
result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
|
||||
assert result is not None
|
||||
assert result.truncated_value == {"truncated": True}
|
||||
assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
|
||||
assert result.offload.file_id == "file-id"
|
||||
assert result.offload.type_ == ExecutionOffLoadType.INPUTS
|
||||
|
||||
|
||||
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"trunc": "i"})
|
||||
db_model.process_data = json.dumps({"trunc": "p"})
|
||||
db_model.outputs = json.dumps({"trunc": "o"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = json.dumps({"total_tokens": 3})
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
|
||||
off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
|
||||
off_in.file = SimpleNamespace(key="k-in")
|
||||
off_out.file = SimpleNamespace(key="k-out")
|
||||
off_proc.file = SimpleNamespace(key="k-proc")
|
||||
db_model.offload_data = [off_out, off_in, off_proc]
|
||||
|
||||
def fake_load(key: str) -> bytes:
|
||||
return json.dumps({"full": key}).encode()
|
||||
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"full": "k-in"}
|
||||
assert domain.outputs == {"full": "k-out"}
|
||||
assert domain.process_data == {"full": "k-proc"}
|
||||
assert domain.get_truncated_inputs() == {"trunc": "i"}
|
||||
assert domain.get_truncated_outputs() == {"trunc": "o"}
|
||||
assert domain.get_truncated_process_data() == {"trunc": "p"}
|
||||
|
||||
|
||||
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"i": 1})
|
||||
db_model.process_data = json.dumps({"p": 2})
|
||||
db_model.outputs = json.dumps({"o": 3})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = "{}"
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
db_model.offload_data = []
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"i": 1}
|
||||
assert domain.outputs == {"o": 3}
|
||||
|
||||
|
||||
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakeConverter:
|
||||
def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
return {"wrapped": values["a"]}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
|
||||
FakeConverter,
|
||||
)
|
||||
assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
|
||||
|
||||
|
||||
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
trunc_result = SimpleNamespace(
|
||||
truncated_value={"trunc": True},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
|
||||
)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
# Inputs should be truncated, outputs/process_data encoded directly
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.inputs) == {"trunc": True}
|
||||
assert json.loads(db_model.outputs) == {"b": 2}
|
||||
assert json.loads(db_model.process_data) == {"c": 3}
|
||||
assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
|
||||
assert execution.get_truncated_inputs() == {"trunc": True}
|
||||
|
||||
|
||||
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
existing = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = existing
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
|
||||
if values == {"b": 2}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"b": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
|
||||
)
|
||||
if values == {"c": 3}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"c": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
|
||||
)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.outputs) == {"b": "trunc"}
|
||||
assert json.loads(db_model.process_data) == {"c": "trunc"}
|
||||
assert execution.get_truncated_outputs() == {"b": "trunc"}
|
||||
assert execution.get_truncated_process_data() == {"c": "trunc"}
|
||||
|
||||
|
||||
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1})
|
||||
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
|
||||
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
merged = session.merge.call_args.args[0]
|
||||
assert merged.inputs == '{"a": 1}'
|
||||
|
||||
|
||||
def test_save_retries_duplicate_and_logs_non_duplicate(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(execution_id="id")
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
other_error = IntegrityError("other", params=None, orig=None)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def persist(_db_model: Any) -> None:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise duplicate_error
|
||||
|
||||
monkeypatch.setattr(repo, "_persist_to_database", persist)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
repo.save(execution)
|
||||
assert execution.id == "new-id"
|
||||
assert repo._node_execution_cache[execution.node_execution_id] is not None
|
||||
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
|
||||
with pytest.raises(IntegrityError):
|
||||
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
|
||||
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_save_logs_and_reraises_on_unexpected_error(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
|
||||
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def __init__(self) -> None:
|
||||
self.where_calls = 0
|
||||
self.order_by_args: tuple[Any, ...] | None = None
|
||||
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
self.where_calls += 1
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.order_by_args = args
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
model1 = SimpleNamespace(node_execution_id="n1")
|
||||
model2 = SimpleNamespace(node_execution_id=None)
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [model1, model2]
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
|
||||
db_models = repo.get_db_models_by_workflow_run("run", order)
|
||||
assert db_models == [model1, model2]
|
||||
assert repo._node_execution_cache["n1"] is model1
|
||||
assert stmt.order_by_args is not None
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.args = args # type: ignore[attr-defined]
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
|
||||
|
||||
|
||||
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
|
||||
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
|
||||
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
|
||||
|
||||
class FakeExecutor:
|
||||
def __enter__(self) -> FakeExecutor:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
|
||||
assert timeout == 30
|
||||
return list(map(func, items))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
|
||||
lambda max_workers: FakeExecutor(),
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("run", order_config=None)
|
||||
assert result == ["domain:db1", "domain:db2"]
|
||||
137
api/tests/unit_tests/core/schemas/test_registry.py
Normal file
137
api/tests/unit_tests/core/schemas/test_registry.py
Normal file
@ -0,0 +1,137 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
|
||||
|
||||
class TestSchemaRegistry:
|
||||
def test_initialization(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
assert registry.base_dir == base_dir
|
||||
assert registry.versions == {}
|
||||
assert registry.metadata == {}
|
||||
|
||||
def test_default_registry_singleton(self):
|
||||
registry1 = SchemaRegistry.default_registry()
|
||||
registry2 = SchemaRegistry.default_registry()
|
||||
assert registry1 is registry2
|
||||
assert isinstance(registry1, SchemaRegistry)
|
||||
|
||||
def test_load_all_versions_non_existent_dir(self, tmp_path):
|
||||
base_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
registry.load_all_versions()
|
||||
assert registry.versions == {}
|
||||
|
||||
def test_load_all_versions_filtering(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
(base_dir / "not_a_version_dir").mkdir()
|
||||
(base_dir / "v1").mkdir()
|
||||
(base_dir / "some_file.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
with patch.object(registry, "_load_version_dir") as mock_load:
|
||||
registry.load_all_versions()
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][0] == "v1"
|
||||
|
||||
def test_load_version_dir_filtering(self, tmp_path):
|
||||
version_dir = tmp_path / "v1"
|
||||
version_dir.mkdir()
|
||||
(version_dir / "schema1.json").write_text("{}")
|
||||
(version_dir / "not_a_schema.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
with patch.object(registry, "_load_schema") as mock_load:
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][1] == "schema1"
|
||||
|
||||
def test_load_version_dir_non_existent(self, tmp_path):
|
||||
version_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
assert "v1" not in registry.versions
|
||||
|
||||
def test_load_schema_success(self, tmp_path):
|
||||
schema_path = tmp_path / "test.json"
|
||||
schema_content = {"title": "Test Schema", "description": "A test schema"}
|
||||
schema_path.write_text(json.dumps(schema_content))
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "test", schema_path)
|
||||
|
||||
assert registry.versions["v1"]["test"] == schema_content
|
||||
uri = "https://dify.ai/schemas/v1/test.json"
|
||||
assert registry.metadata[uri]["title"] == "Test Schema"
|
||||
assert registry.metadata[uri]["version"] == "v1"
|
||||
|
||||
def test_load_schema_invalid_json(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "invalid.json"
|
||||
schema_path.write_text("invalid json")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "invalid", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/invalid" in caplog.text
|
||||
|
||||
def test_load_schema_os_error(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "error.json"
|
||||
schema_path.write_text("{}")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
|
||||
with patch("builtins.open", side_effect=OSError("Read error")):
|
||||
registry._load_schema("v1", "error", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/error" in caplog.text
|
||||
|
||||
def test_get_schema(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"type": "object"}}}
|
||||
|
||||
# Valid URI
|
||||
assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
|
||||
|
||||
# Invalid URI
|
||||
assert registry.get_schema("invalid-uri") is None
|
||||
|
||||
# Missing version
|
||||
assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
|
||||
|
||||
def test_list_versions(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v2": {}, "v1": {}}
|
||||
assert registry.list_versions() == ["v1", "v2"]
|
||||
|
||||
def test_list_schemas(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"b": {}, "a": {}}}
|
||||
|
||||
assert registry.list_schemas("v1") == ["a", "b"]
|
||||
assert registry.list_schemas("v2") == []
|
||||
|
||||
def test_get_all_schemas_for_version(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"title": "Test Label"}}}
|
||||
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "test"
|
||||
assert results[0]["label"] == "Test Label"
|
||||
assert results[0]["schema"] == {"title": "Test Label"}
|
||||
|
||||
# Default label if title missing
|
||||
registry.versions["v1"]["no_title"] = {}
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
item = next(r for r in results if r["name"] == "no_title")
|
||||
assert item["label"] == "no_title"
|
||||
|
||||
# Empty if version missing
|
||||
assert registry.get_all_schemas_for_version("v2") == []
|
||||
80
api/tests/unit_tests/core/schemas/test_schema_manager.py
Normal file
80
api/tests/unit_tests/core/schemas/test_schema_manager.py
Normal file
@ -0,0 +1,80 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
|
||||
|
||||
def test_init_with_provided_registry():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
|
||||
def test_init_with_default_registry(mock_default_registry):
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_default_registry.return_value = mock_registry
|
||||
|
||||
manager = SchemaManager()
|
||||
|
||||
mock_default_registry.assert_called_once()
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
def test_get_all_schema_definitions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
|
||||
mock_registry.get_all_schemas_for_version.return_value = expected_definitions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_all_schema_definitions(version="v2")
|
||||
|
||||
mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
|
||||
assert result == expected_definitions
|
||||
|
||||
|
||||
def test_get_schema_by_name_success():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_schema = {"type": "object"}
|
||||
mock_registry.get_schema.return_value = mock_schema
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("my_schema", version="v1")
|
||||
|
||||
expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
|
||||
mock_registry.get_schema.assert_called_once_with(expected_uri)
|
||||
assert result == {"name": "my_schema", "schema": mock_schema}
|
||||
|
||||
|
||||
def test_get_schema_by_name_not_found():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_registry.get_schema.return_value = None
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("non_existent", version="v1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_available_schemas():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_schemas = ["schema1", "schema2"]
|
||||
mock_registry.list_schemas.return_value = expected_schemas
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_schemas(version="v1")
|
||||
|
||||
mock_registry.list_schemas.assert_called_once_with("v1")
|
||||
assert result == expected_schemas
|
||||
|
||||
|
||||
def test_list_available_versions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_versions = ["v1", "v2"]
|
||||
mock_registry.list_versions.return_value = expected_versions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_versions()
|
||||
|
||||
mock_registry.list_versions.assert_called_once()
|
||||
assert result == expected_versions
|
||||
@ -1,32 +1,34 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
def mock_provider_entity():
|
||||
mock_entity = Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
provider_credential_schema = Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
model_credential_schema = Mock()
|
||||
type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
|
||||
ps = ProviderModelSetting(
|
||||
@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test_get_default_model_uses_first_available_active_model():
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = [
|
||||
Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")),
|
||||
Mock(model="gpt-4", provider=Mock(provider="openai")),
|
||||
]
|
||||
|
||||
manager = ProviderManager()
|
||||
with (
|
||||
patch("core.provider_manager.db.session", mock_session),
|
||||
patch.object(manager, "get_configurations", return_value=provider_configurations),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result is not None
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
assert result.provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
mock_session.add.assert_called_once()
|
||||
saved_default_model = mock_session.add.call_args.args[0]
|
||||
assert saved_default_model.model_name == "gpt-3.5-turbo"
|
||||
assert saved_default_model.provider_name == "openai"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode
|
||||
from dify_graph.nodes.http_request import HttpRequestNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
@ -68,6 +68,8 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
|
||||
# LLM-like nodes now require an http_client; provide a mock by default for tests.
|
||||
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
|
||||
|
||||
# Ensure TemplateTransformNode receives a renderer now required by constructor
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
|
||||
@ -4,9 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import dify_graph.graph_engine.response_coordinator.session as response_session_module
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType
|
||||
from dify_graph.graph_engine.response_coordinator import RESPONSE_SESSION_NODE_TYPES
|
||||
from dify_graph.graph_engine.response_coordinator.session import ResponseSession
|
||||
from dify_graph.nodes.base.template import Template, TextSegment
|
||||
|
||||
@ -35,28 +33,14 @@ class DummyNodeWithoutStreamingTemplate:
|
||||
self.state = NodeState.UNKNOWN
|
||||
|
||||
|
||||
def test_response_session_from_node_rejects_node_types_outside_allowlist() -> None:
|
||||
"""Unsupported node types are rejected even if they expose a template."""
|
||||
def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None:
|
||||
"""Session creation depends on the streaming-template contract rather than node type."""
|
||||
node = DummyResponseNode(
|
||||
node_id="llm-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
template=Template(segments=[TextSegment(text="hello")]),
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="RESPONSE_SESSION_NODE_TYPES"):
|
||||
ResponseSession.from_node(node)
|
||||
|
||||
|
||||
def test_response_session_from_node_supports_downstream_allowlist_extension(monkeypatch) -> None:
|
||||
"""Downstream applications can extend the supported node-type list."""
|
||||
node = DummyResponseNode(
|
||||
node_id="llm-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
template=Template(segments=[TextSegment(text="hello")]),
|
||||
)
|
||||
extended_node_types = [*RESPONSE_SESSION_NODE_TYPES, BuiltinNodeTypes.LLM]
|
||||
monkeypatch.setattr(response_session_module, "RESPONSE_SESSION_NODE_TYPES", extended_node_types)
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
|
||||
assert session.node_id == "llm-node"
|
||||
|
||||
145
api/tests/unit_tests/core/workflow/graph_engine/test_worker.py
Normal file
145
api/tests/unit_tests/core/workflow/graph_engine/test_worker.py
Normal file
@ -0,0 +1,145 @@
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from dify_graph.graph_engine.worker import Worker
|
||||
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
|
||||
|
||||
|
||||
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
|
||||
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=InMemoryReadyQueue(),
|
||||
event_queue=queue.Queue(),
|
||||
graph=MagicMock(),
|
||||
layers=[],
|
||||
)
|
||||
node = SimpleNamespace(
|
||||
execution_id="exec-1",
|
||||
id="node-1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
)
|
||||
|
||||
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
|
||||
|
||||
assert event.start_at == fixed_time
|
||||
assert event.finished_at == fixed_time
|
||||
assert event.error == "boom"
|
||||
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert event.node_run_result.error == "boom"
|
||||
assert event.node_run_result.error_type == "RuntimeError"
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
|
||||
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
failure_time = start_at + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeNode:
|
||||
execution_id = "exec-1"
|
||||
id = "node-1"
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="LLM",
|
||||
start_at=start_at,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"node-1": FakeNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["node-1"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 1:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == start_at
|
||||
assert fallback_event.finished_at == failure_time
|
||||
assert fallback_event.error == "queue boom"
|
||||
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
|
||||
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
child_start = parent_start + timedelta(seconds=3)
|
||||
failure_time = parent_start + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeIterationNode:
|
||||
execution_id = "iteration-exec"
|
||||
id = "iteration-node"
|
||||
node_type = BuiltinNodeTypes.ITERATION
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="Iteration",
|
||||
start_at=parent_start,
|
||||
)
|
||||
yield NodeRunStartedEvent(
|
||||
id="child-exec",
|
||||
node_id="child-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
node_title="LLM",
|
||||
start_at=child_start,
|
||||
in_iteration_id=self.id,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["iteration-node"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 2:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == parent_start
|
||||
assert fallback_event.finished_at == failure_time
|
||||
@ -14,3 +14,64 @@ def test_render_body_template_replaces_variable_values():
|
||||
result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool)
|
||||
|
||||
assert result == "Hello World https://example.com"
|
||||
|
||||
|
||||
def test_render_markdown_body_renders_markdown_to_html():
|
||||
rendered = EmailDeliveryConfig.render_markdown_body("**Bold** and [link](https://example.com)")
|
||||
|
||||
assert "<strong>Bold</strong>" in rendered
|
||||
assert '<a href="https://example.com">link</a>' in rendered
|
||||
|
||||
|
||||
def test_render_markdown_body_sanitizes_unsafe_html():
|
||||
rendered = EmailDeliveryConfig.render_markdown_body(
|
||||
'<script>alert("xss")</script><a href="javascript:alert(1)" onclick="alert(2)">Click</a>'
|
||||
)
|
||||
|
||||
assert "<script" not in rendered
|
||||
assert "<a" not in rendered
|
||||
assert "onclick" not in rendered
|
||||
assert "javascript:" not in rendered
|
||||
assert "Click" in rendered
|
||||
|
||||
|
||||
def test_render_markdown_body_sanitizes_markdown_link_with_javascript_href():
|
||||
rendered = EmailDeliveryConfig.render_markdown_body("[bad](javascript:alert(1)) and [ok](https://example.com)")
|
||||
|
||||
assert "javascript:" not in rendered
|
||||
assert "<a>bad</a>" in rendered
|
||||
assert '<a href="https://example.com">ok</a>' in rendered
|
||||
|
||||
|
||||
def test_render_markdown_body_does_not_allow_raw_html_tags():
|
||||
rendered = EmailDeliveryConfig.render_markdown_body("<b>raw html</b> and **markdown**")
|
||||
|
||||
assert "<b>" not in rendered
|
||||
assert "raw html" in rendered
|
||||
assert "<strong>markdown</strong>" in rendered
|
||||
|
||||
|
||||
def test_render_markdown_body_supports_table_syntax():
|
||||
rendered = EmailDeliveryConfig.render_markdown_body("| h1 | h2 |\n| --- | ---: |\n| v1 | v2 |")
|
||||
|
||||
assert "<table>" in rendered
|
||||
assert "<thead>" in rendered
|
||||
assert "<tbody>" in rendered
|
||||
assert 'align="right"' in rendered
|
||||
assert "style=" not in rendered
|
||||
|
||||
|
||||
def test_sanitize_subject_removes_crlf():
|
||||
sanitized = EmailDeliveryConfig.sanitize_subject("Notice\r\nBCC:attacker@example.com")
|
||||
|
||||
assert "\r" not in sanitized
|
||||
assert "\n" not in sanitized
|
||||
assert sanitized == "Notice BCC:attacker@example.com"
|
||||
|
||||
|
||||
def test_sanitize_subject_removes_html_tags():
|
||||
sanitized = EmailDeliveryConfig.sanitize_subject("<b>Alert</b><img src=x onerror=1>")
|
||||
|
||||
assert "<" not in sanitized
|
||||
assert ">" not in sanitized
|
||||
assert sanitized == "Alert"
|
||||
|
||||
@ -0,0 +1,63 @@
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.graph_events import NodeRunSucceededEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from dify_graph.nodes.iteration.iteration_node import IterationNode
|
||||
|
||||
|
||||
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
|
||||
node = IterationNode.__new__(IterationNode)
|
||||
node._node_data = IterationNodeData(
|
||||
title="Parallel Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
output_selector=["iteration", "output"],
|
||||
is_parallel=True,
|
||||
parallel_nums=2,
|
||||
error_handle_mode=ErrorHandleMode.TERMINATED,
|
||||
)
|
||||
node._capture_execution_context = lambda: nullcontext()
|
||||
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
|
||||
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
|
||||
|
||||
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
|
||||
return (
|
||||
0.1 + (index * 0.1),
|
||||
[
|
||||
NodeRunSucceededEvent(
|
||||
id=f"exec-{index}",
|
||||
node_id=f"llm-{index}",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
start_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
),
|
||||
],
|
||||
f"output-{item}",
|
||||
{},
|
||||
LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
|
||||
|
||||
outputs: list[object] = []
|
||||
iter_run_map: dict[str, float] = {}
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
generator = node._execute_parallel_iterations(
|
||||
iterator_list_value=["a", "b"],
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
for _ in generator:
|
||||
# Simulate a slow consumer replaying buffered events.
|
||||
time.sleep(0.02)
|
||||
|
||||
assert outputs == ["output-a", "output-b"]
|
||||
assert iter_run_map["0"] == pytest.approx(0.1)
|
||||
assert iter_run_map["1"] == pytest.approx(0.2)
|
||||
@ -1,18 +1,26 @@
|
||||
"""Tests for llm_utils module, specifically multimodal content handling."""
|
||||
"""Tests for llm_utils module, specifically multimodal content handling and prompt message construction."""
|
||||
|
||||
import string
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.nodes.llm.llm_utils import (
|
||||
_truncate_multimodal_content,
|
||||
build_context,
|
||||
restore_multimodal_content_in_messages,
|
||||
)
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
class TestTruncateMultimodalContent:
|
||||
@ -50,7 +58,6 @@ class TestTruncateMultimodalContent:
|
||||
assert isinstance(result_content, ImagePromptMessageContent)
|
||||
assert result_content.base64_data == ""
|
||||
assert result_content.url == ""
|
||||
# file_ref should be preserved
|
||||
assert result_content.file_ref == "local:test-file-id"
|
||||
|
||||
def test_truncates_base64_when_no_file_ref(self):
|
||||
@ -70,7 +77,6 @@ class TestTruncateMultimodalContent:
|
||||
assert isinstance(result.content, list)
|
||||
result_content = result.content[0]
|
||||
assert isinstance(result_content, ImagePromptMessageContent)
|
||||
# Should be truncated with marker
|
||||
assert "...[TRUNCATED]..." in result_content.base64_data
|
||||
assert len(result_content.base64_data) < len(long_base64)
|
||||
|
||||
@ -89,9 +95,7 @@ class TestTruncateMultimodalContent:
|
||||
|
||||
assert isinstance(result.content, list)
|
||||
assert len(result.content) == 2
|
||||
# Text content unchanged
|
||||
assert result.content[0].data == "Hello!"
|
||||
# Image content base64 cleared
|
||||
assert result.content[1].base64_data == ""
|
||||
|
||||
|
||||
@ -100,8 +104,6 @@ class TestBuildContext:
|
||||
|
||||
def test_excludes_system_messages(self):
|
||||
"""System messages should be excluded from context."""
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Hello!"),
|
||||
@ -109,7 +111,6 @@ class TestBuildContext:
|
||||
|
||||
context = build_context(messages, "Hi there!")
|
||||
|
||||
# Should have user message + assistant response, no system message
|
||||
assert len(context) == 2
|
||||
assert context[0].content == "Hello!"
|
||||
assert context[1].content == "Hi there!"
|
||||
@ -140,7 +141,6 @@ class TestBuildContext:
|
||||
|
||||
messages = [UserPromptMessage(content="What's the weather in Beijing?")]
|
||||
|
||||
# Create trace with tool call and result
|
||||
generation_data = LLMGenerationData(
|
||||
text="The weather in Beijing is sunny, 25°C.",
|
||||
reasoning_contents=[],
|
||||
@ -183,7 +183,6 @@ class TestBuildContext:
|
||||
accumulated_response = "Let me check the weather.The weather in Beijing is sunny, 25°C."
|
||||
context = build_context(messages, accumulated_response, generation_data)
|
||||
|
||||
# Should have: user message + assistant with tool_call + tool result + final assistant
|
||||
assert len(context) == 4
|
||||
assert context[0].content == "What's the weather in Beijing?"
|
||||
assert isinstance(context[1], AssistantPromptMessage)
|
||||
@ -223,7 +222,6 @@ class TestBuildContext:
|
||||
finish_reason="stop",
|
||||
files=[],
|
||||
trace=[
|
||||
# First model call with two tool calls
|
||||
LLMTraceSegment(
|
||||
type="model",
|
||||
duration=0.5,
|
||||
@ -237,7 +235,6 @@ class TestBuildContext:
|
||||
],
|
||||
),
|
||||
),
|
||||
# First tool result
|
||||
LLMTraceSegment(
|
||||
type="tool",
|
||||
duration=0.2,
|
||||
@ -249,7 +246,6 @@ class TestBuildContext:
|
||||
output="Sunny, 25°C",
|
||||
),
|
||||
),
|
||||
# Second tool result
|
||||
LLMTraceSegment(
|
||||
type="tool",
|
||||
duration=0.2,
|
||||
@ -267,7 +263,6 @@ class TestBuildContext:
|
||||
accumulated_response = "I'll check both cities.Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
|
||||
context = build_context(messages, accumulated_response, generation_data)
|
||||
|
||||
# Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
|
||||
assert len(context) == 5
|
||||
assert context[0].content == "Compare weather in Beijing and Shanghai"
|
||||
assert isinstance(context[1], AssistantPromptMessage)
|
||||
@ -304,12 +299,11 @@ class TestBuildContext:
|
||||
usage=LLMUsage.empty_usage(),
|
||||
finish_reason="stop",
|
||||
files=[],
|
||||
trace=[], # Empty trace
|
||||
trace=[],
|
||||
)
|
||||
|
||||
context = build_context(messages, "Hi there!", generation_data)
|
||||
|
||||
# Should fallback to simple context
|
||||
assert len(context) == 2
|
||||
assert context[0].content == "Hello!"
|
||||
assert context[1].content == "Hi there!"
|
||||
@ -321,7 +315,6 @@ class TestRestoreMultimodalContentInMessages:
|
||||
@patch("dify_graph.file.file_manager.restore_multimodal_content")
|
||||
def test_restores_multimodal_content(self, mock_restore):
|
||||
"""Should restore multimodal content in messages."""
|
||||
# Setup mock
|
||||
restored_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="restored-base64",
|
||||
@ -330,7 +323,6 @@ class TestRestoreMultimodalContentInMessages:
|
||||
)
|
||||
mock_restore.return_value = restored_content
|
||||
|
||||
# Create message with truncated content
|
||||
truncated_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="",
|
||||
@ -363,3 +355,98 @@ class TestRestoreMultimodalContentInMessages:
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content[0].data == "Hello!"
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import LLMFileSaver
|
||||
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
@ -107,6 +107,7 @@ def llm_node(
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -121,6 +122,7 @@ def llm_node(
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
|
||||
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
|
||||
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="",
|
||||
jinja2_text="Hello, {{ name }}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="jinja2",
|
||||
)
|
||||
]
|
||||
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=None,
|
||||
jinja2_variables=[],
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
template_renderer=llm_node._template_renderer,
|
||||
)
|
||||
|
||||
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
|
||||
llm_node._template_renderer.render_jinja2.assert_called_once_with(
|
||||
template="Hello, {{ name }}",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
memory = mock.MagicMock(spec=MockTokenBufferMemory)
|
||||
memory.get_history_prompt_messages.return_value = [
|
||||
@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
window=MemoryConfig.WindowConfig(enabled=True, size=3),
|
||||
)
|
||||
|
||||
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = llm_utils.handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
@ -1,5 +1,14 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.question_classifier import (
|
||||
QuestionClassifierNode,
|
||||
QuestionClassifierNodeData,
|
||||
)
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data():
|
||||
@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
|
||||
assert node_data.vision.enabled == False
|
||||
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
|
||||
node_data = QuestionClassifierNodeData.model_validate(
|
||||
{
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
}
|
||||
)
|
||||
template_renderer = MagicMock(spec=TemplateRenderer)
|
||||
node = QuestionClassifierNode(
|
||||
id="node-id",
|
||||
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="workflow-id",
|
||||
graph_config={},
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
),
|
||||
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
llm_file_saver=MagicMock(),
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
fetch_prompt_messages = MagicMock(return_value=([], None))
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
|
||||
fetch_prompt_messages,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
|
||||
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
|
||||
)
|
||||
|
||||
node._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query="hello",
|
||||
model_instance=MagicMock(stop=(), parameters={}),
|
||||
context="",
|
||||
)
|
||||
|
||||
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
|
||||
|
||||
@ -0,0 +1,63 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||
init_params = build_test_graph_init_params(
|
||||
graph_config=graph_config,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", files=[]),
|
||||
user_inputs={"payload": "value"},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def _build_node_config() -> NodeConfigDict:
|
||||
return NodeConfigDictAdapter.validate_python(
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": TRIGGER_PLUGIN_NODE_TYPE,
|
||||
"title": "Trigger Event",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"subscription_id": "subscription-id",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
"event_parameters": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
id="node-1",
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
}
|
||||
@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDefaultLLMTemplateRenderer:
|
||||
def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
|
||||
renderer = node_factory.DefaultLLMTemplateRenderer()
|
||||
execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = renderer.render_jinja2(
|
||||
template="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
assert result == "hello world"
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
llm_template_renderer = sentinel.llm_template_renderer
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"DefaultLLMTemplateRenderer",
|
||||
return_value=llm_template_renderer,
|
||||
) as llm_renderer_factory,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
llm_renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
|
||||
assert factory._llm_template_renderer is llm_template_renderer
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._llm_template_renderer = sentinel.llm_template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(
|
||||
BuiltinNodeTypes.LLM,
|
||||
"LLMNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
"QuestionClassifierNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user