mirror of
https://github.com/langgenius/dify.git
synced 2026-02-23 03:17:57 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -91,6 +91,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
"pool_size": 30,
|
||||
"pool_use_lifo": False,
|
||||
"pool_reset_on_return": None,
|
||||
"pool_timeout": 30,
|
||||
}
|
||||
|
||||
assert config["CONSOLE_WEB_URL"] == "https://example.com"
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, NamedTuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import marshal
|
||||
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
@ -9,11 +11,14 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
_serialize_full_content,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
_TEST_APP_ID = "test_app_id"
|
||||
@ -21,6 +26,54 @@ _TEST_NODE_EXEC_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableFields:
|
||||
def test_serialize_full_content(self):
|
||||
"""Test that _serialize_full_content uses pre-loaded relationships."""
|
||||
# Create mock objects with relationships pre-loaded
|
||||
mock_variable_file = MagicMock(spec=WorkflowDraftVariableFile)
|
||||
mock_variable_file.size = 100000
|
||||
mock_variable_file.length = 50
|
||||
mock_variable_file.value_type = SegmentType.OBJECT
|
||||
mock_variable_file.upload_file_id = "test-upload-file-id"
|
||||
|
||||
mock_variable = MagicMock(spec=WorkflowDraftVariable)
|
||||
mock_variable.file_id = "test-file-id"
|
||||
mock_variable.variable_file = mock_variable_file
|
||||
|
||||
# Mock the file helpers
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
|
||||
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
|
||||
|
||||
# Call the function
|
||||
result = _serialize_full_content(mock_variable)
|
||||
|
||||
# Verify it returns the expected structure
|
||||
assert result is not None
|
||||
assert result["size_bytes"] == 100000
|
||||
assert result["length"] == 50
|
||||
assert result["value_type"] == "object"
|
||||
assert "download_url" in result
|
||||
assert result["download_url"] == "http://example.com/signed-url"
|
||||
|
||||
# Verify it used the pre-loaded relationships (no database queries)
|
||||
mock_file_helpers.get_signed_file_url.assert_called_once_with("test-upload-file-id", as_attachment=True)
|
||||
|
||||
def test_serialize_full_content_handles_none_cases(self):
|
||||
"""Test that _serialize_full_content handles None cases properly."""
|
||||
|
||||
# Test with no file_id
|
||||
draft_var = WorkflowDraftVariable()
|
||||
draft_var.file_id = None
|
||||
result = _serialize_full_content(draft_var)
|
||||
assert result is None
|
||||
|
||||
def test_serialize_full_content_should_raises_when_file_id_exists_but_file_is_none(self):
|
||||
# Test with no file_id
|
||||
draft_var = WorkflowDraftVariable()
|
||||
draft_var.file_id = str(uuid.uuid4())
|
||||
draft_var.variable_file = None
|
||||
with pytest.raises(AssertionError):
|
||||
result = _serialize_full_content(draft_var)
|
||||
|
||||
def test_conversation_variable(self):
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||
@ -39,12 +92,14 @@ class TestWorkflowDraftVariableFields:
|
||||
"value_type": "number",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"is_truncated": False,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = 1
|
||||
expected_with_value["full_content"] = None
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_create_sys_variable(self):
|
||||
@ -70,11 +125,13 @@ class TestWorkflowDraftVariableFields:
|
||||
"value_type": "string",
|
||||
"edited": True,
|
||||
"visible": True,
|
||||
"is_truncated": False,
|
||||
}
|
||||
)
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = "a"
|
||||
expected_with_value["full_content"] = None
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_node_variable(self):
|
||||
@ -100,14 +157,65 @@ class TestWorkflowDraftVariableFields:
|
||||
"value_type": "array[any]",
|
||||
"edited": True,
|
||||
"visible": False,
|
||||
"is_truncated": False,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = [1, "a"]
|
||||
expected_with_value["full_content"] = None
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_node_variable_with_file(self):
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="node_var",
|
||||
value=build_segment([1, "a"]),
|
||||
visible=False,
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var.last_edited_at = naive_utc_now()
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
id=str(uuidv7()),
|
||||
upload_file_id=str(uuid.uuid4()),
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.ARRAY_STRING,
|
||||
)
|
||||
node_var.variable_file = variable_file
|
||||
node_var.file_id = variable_file.id
|
||||
|
||||
expected_without_value: OrderedDict[str, Any] = OrderedDict(
|
||||
{
|
||||
"id": str(node_var.id),
|
||||
"type": node_var.get_variable_type().value,
|
||||
"name": "node_var",
|
||||
"description": "",
|
||||
"selector": ["test_node", "node_var"],
|
||||
"value_type": "array[any]",
|
||||
"edited": True,
|
||||
"visible": False,
|
||||
"is_truncated": True,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
|
||||
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = [1, "a"]
|
||||
expected_with_value["full_content"] = {
|
||||
"size_bytes": 1024,
|
||||
"value_type": "array[string]",
|
||||
"length": 10,
|
||||
"download_url": "http://example.com/signed-url",
|
||||
}
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableList:
|
||||
def test_workflow_draft_variable_list(self):
|
||||
@ -135,6 +243,7 @@ class TestWorkflowDraftVariableList:
|
||||
"value_type": "string",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"is_truncated": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ from flask_restx import Api
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
@ -27,31 +26,33 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email sends reset password email when registration is allowed."""
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
result = login_api.post()
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
|
||||
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@ -87,16 +88,17 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AccountNotFound when registration is disabled."""
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
@ -107,10 +109,12 @@ class TestAuthenticationSecurity:
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AccountNotFound) as exc_info:
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "account_not_found"
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.console.auth.oauth import (
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
class TestGetOAuthProviders:
|
||||
@ -201,9 +201,9 @@ class TestOAuthCallback:
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
# Import the real requests module to create a proper exception
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
request_exception = requests.exceptions.RequestException("OAuth error")
|
||||
request_exception = httpx.RequestError("OAuth error")
|
||||
request_exception.response = MagicMock()
|
||||
request_exception.response.text = str(exception)
|
||||
|
||||
@ -451,7 +451,7 @@ class TestAccountGeneration:
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
if not allow_register and not existing_account:
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
else:
|
||||
result = _generate_account("github", user_info)
|
||||
|
||||
@ -82,6 +82,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.task_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
@ -125,13 +126,18 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
@ -214,6 +220,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.task_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
@ -257,8 +264,10 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
@ -275,6 +284,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
mock_conv_var_class.from_variable.side_effect = mock_conv_vars
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
@ -361,6 +373,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.task_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
@ -396,13 +409,18 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
storage_key="storage_key_123",
|
||||
)
|
||||
|
||||
def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
|
||||
def create_file_dict(self, file_id: str = "test_file_dict"):
|
||||
"""Create a file dictionary with correct dify_model_identity"""
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
|
||||
@ -0,0 +1,430 @@
|
||||
"""
|
||||
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessDataResponseScenario:
|
||||
"""Test scenario for process_data in responses."""
|
||||
|
||||
name: str
|
||||
original_process_data: dict[str, Any] | None
|
||||
truncated_process_data: dict[str, Any] | None
|
||||
expected_response_data: dict[str, Any] | None
|
||||
expected_truncated_flag: bool
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterCenarios:
|
||||
"""Test process_data truncation in WorkflowResponseConverter."""
|
||||
|
||||
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
|
||||
"""Create a mock WorkflowAppGenerateEntity."""
|
||||
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
|
||||
mock_app_config = Mock()
|
||||
mock_app_config.tenant_id = "test-tenant-id"
|
||||
mock_entity.app_config = mock_app_config
|
||||
return mock_entity
|
||||
|
||||
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
||||
"""Create a WorkflowResponseConverter for testing."""
|
||||
|
||||
mock_entity = self.create_mock_generate_entity()
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user)
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
truncated_process_data: dict[str, Any] | None = None,
|
||||
execution_id: str = "test-execution-id",
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution for testing."""
|
||||
execution = WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
)
|
||||
|
||||
if truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(truncated_process_data)
|
||||
|
||||
return execution
|
||||
|
||||
def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
|
||||
"""Create a QueueNodeSucceededEvent for testing."""
|
||||
return QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
def create_node_retry_event(self) -> QueueNodeRetryEvent:
|
||||
"""Create a QueueNodeRetryEvent for testing."""
|
||||
return QueueNodeRetryEvent(
|
||||
inputs={"data": "inputs"},
|
||||
outputs={"data": "outputs"},
|
||||
error="oops",
|
||||
retry_index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_title="test code",
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
def test_workflow_node_finish_response_uses_truncated_process_data(self):
|
||||
"""Test that node finish response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
execution = self.create_workflow_node_execution(
|
||||
process_data=original_data, truncated_process_data=truncated_data
|
||||
)
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_finish_response_without_truncation(self):
|
||||
"""Test node finish response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_with_none_process_data(self):
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=None)
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should have None process_data
|
||||
assert response is not None
|
||||
assert response.data.process_data is None
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
execution = self.create_workflow_node_execution(
|
||||
process_data=original_data, truncated_process_data=truncated_data
|
||||
)
|
||||
event = self.create_node_retry_event()
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_retry_response_without_truncation(self):
|
||||
"""Test node retry response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
event = self.create_node_retry_event()
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_iteration_and_loop_nodes_return_none(self):
|
||||
"""Test that iteration and loop nodes return None (no change from existing behavior)."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
# Test iteration node
|
||||
iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
iteration_execution.node_type = NodeType.ITERATION
|
||||
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=iteration_execution,
|
||||
)
|
||||
|
||||
# Should return None for iteration nodes
|
||||
assert response is None
|
||||
|
||||
# Test loop node
|
||||
loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
loop_execution.node_type = NodeType.LOOP
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=loop_execution,
|
||||
)
|
||||
|
||||
# Should return None for loop nodes
|
||||
assert response is None
|
||||
|
||||
def test_execution_without_workflow_execution_id_returns_none(self):
|
||||
"""Test that executions without workflow_execution_id return None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
execution.workflow_execution_id = None # Single-step debugging
|
||||
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Should return None for single-step debugging
|
||||
assert response is None
|
||||
|
||||
@staticmethod
|
||||
def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
|
||||
"""Create test scenarios for process_data responses."""
|
||||
return [
|
||||
ProcessDataResponseScenario(
|
||||
name="none_process_data",
|
||||
original_process_data=None,
|
||||
truncated_process_data=None,
|
||||
expected_response_data=None,
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="small_process_data_no_truncation",
|
||||
original_process_data={"small": "data"},
|
||||
truncated_process_data=None,
|
||||
expected_response_data={"small": "data"},
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="large_process_data_with_truncation",
|
||||
original_process_data={"large": "x" * 10000, "metadata": "info"},
|
||||
truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
expected_truncated_flag=True,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="empty_process_data",
|
||||
original_process_data={},
|
||||
truncated_process_data=None,
|
||||
expected_response_data={},
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="complex_data_with_truncation",
|
||||
original_process_data={
|
||||
"logs": ["entry"] * 1000, # Large array
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
truncated_process_data={
|
||||
"logs": "[TRUNCATED: 1000 items]",
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
expected_response_data={
|
||||
"logs": "[TRUNCATED: 1000 items]",
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
expected_truncated_flag=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
get_process_data_response_scenarios(),
|
||||
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
|
||||
)
|
||||
def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
|
||||
"""Test various scenarios for node finish responses."""
|
||||
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=scenario.original_process_data,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
)
|
||||
|
||||
if scenario.truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(scenario.truncated_process_data)
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == scenario.expected_response_data
|
||||
assert response.data.process_data_truncated == scenario.expected_truncated_flag
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
get_process_data_response_scenarios(),
|
||||
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
|
||||
)
|
||||
def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
|
||||
"""Test various scenarios for node retry responses."""
|
||||
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=scenario.original_process_data,
|
||||
status=WorkflowNodeExecutionStatus.FAILED, # Retry scenario
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
)
|
||||
|
||||
if scenario.truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(scenario.truncated_process_data)
|
||||
|
||||
event = self.create_node_retry_event()
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == scenario.expected_response_data
|
||||
assert response.data.process_data_truncated == scenario.expected_truncated_flag
|
||||
@ -83,7 +83,7 @@ def test_client_session_initialize():
|
||||
# Create message handler
|
||||
def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
):
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
@ -28,7 +29,7 @@ class TestHandleMCPRequest:
|
||||
"""Setup test fixtures"""
|
||||
self.app = Mock(spec=App)
|
||||
self.app.name = "test_app"
|
||||
self.app.mode = AppMode.CHAT.value
|
||||
self.app.mode = AppMode.CHAT
|
||||
|
||||
self.mcp_server = Mock(spec=AppMCPServer)
|
||||
self.mcp_server.description = "Test server"
|
||||
@ -195,7 +196,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_list_tools(self):
|
||||
"""Test list tools handler"""
|
||||
app_name = "test_app"
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
description = "Test server"
|
||||
parameters_dict: dict[str, str] = {}
|
||||
user_input_form: list[VariableEntity] = []
|
||||
@ -211,7 +212,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_call_tool(self, mock_app_generate):
|
||||
"""Test call tool handler"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
# Create mock request
|
||||
mock_request = Mock()
|
||||
@ -251,7 +252,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_chat_mode(self):
|
||||
"""Test building parameter schema for chat mode"""
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
parameters_dict: dict[str, str] = {"name": "Enter your name"}
|
||||
|
||||
user_input_form = [
|
||||
@ -274,7 +275,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_workflow_mode(self):
|
||||
"""Test building parameter schema for workflow mode"""
|
||||
app_mode = AppMode.WORKFLOW.value
|
||||
app_mode = AppMode.WORKFLOW
|
||||
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
|
||||
|
||||
user_input_form = [
|
||||
@ -297,7 +298,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_chat_mode(self):
|
||||
"""Test preparing tool arguments for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
arguments = {"query": "test question", "name": "John"}
|
||||
|
||||
@ -311,7 +312,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_workflow_mode(self):
|
||||
"""Test preparing tool arguments for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
arguments = {"input_text": "test input"}
|
||||
|
||||
@ -323,7 +324,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_completion_mode(self):
|
||||
"""Test preparing tool arguments for completion mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
arguments = {"name": "John"}
|
||||
|
||||
@ -335,7 +336,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_chat(self):
|
||||
"""Test extracting answer from mapping response for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
response = {"answer": "test answer", "other": "data"}
|
||||
|
||||
@ -346,7 +347,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_workflow(self):
|
||||
"""Test extracting answer from mapping response for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
response = {"data": {"outputs": {"result": "test result"}}}
|
||||
|
||||
@ -434,7 +435,7 @@ class TestUtilityFunctions:
|
||||
assert parameters["category"]["enum"] == ["A", "B", "C"]
|
||||
|
||||
assert "count" in parameters
|
||||
assert parameters["count"]["type"] == "float"
|
||||
assert parameters["count"]["type"] == "number"
|
||||
|
||||
# FILE type should be skipped - it creates empty dict but gets filtered later
|
||||
# Check that it doesn't have any meaningful content
|
||||
@ -447,3 +448,65 @@ class TestUtilityFunctions:
|
||||
assert "category" not in required
|
||||
|
||||
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
||||
|
||||
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
|
||||
"""Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
|
||||
user_input_form = [
|
||||
VariableEntity(
|
||||
type=VariableEntityType.NUMBER,
|
||||
variable="count",
|
||||
description="Count",
|
||||
label="Count",
|
||||
required=True,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
variable="name",
|
||||
description="User name",
|
||||
label="Name",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict = {
|
||||
"count": "Enter count",
|
||||
"name": "Enter your name",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
# Build a complete JSON Schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
# 1) The schema itself must be valid
|
||||
jsonschema.Draft202012Validator.check_schema(schema)
|
||||
|
||||
# 2) Both float and integer instances should pass validation
|
||||
jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
|
||||
jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
|
||||
|
||||
def test_legacy_float_type_schema_is_invalid(self):
|
||||
"""Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
|
||||
# Manually construct a legacy/incorrect schema (simulating old behavior)
|
||||
bad_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "float", # Invalid type: JSON Schema does not support 'float'
|
||||
"description": "Enter count",
|
||||
}
|
||||
},
|
||||
"required": ["count"],
|
||||
}
|
||||
|
||||
# The schema itself should raise a SchemaError
|
||||
with pytest.raises(jsonschema.exceptions.SchemaError):
|
||||
jsonschema.Draft202012Validator.check_schema(bad_schema)
|
||||
|
||||
# Or validation should also raise SchemaError
|
||||
with pytest.raises(jsonschema.exceptions.SchemaError):
|
||||
jsonschema.validate(instance={"count": 1.23}, schema=bad_schema)
|
||||
|
||||
@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
print(f"job_id: {job_id}")
|
||||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
|
||||
@ -15,7 +15,7 @@ from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
|
||||
@ -0,0 +1,210 @@
|
||||
"""Unit tests for workflow node execution conflict handling."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionConflictHandling:
|
||||
"""Test cases for handling duplicate key conflicts in workflow node execution."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create a mock user with tenant_id
|
||||
self.mock_user = Mock(spec=Account)
|
||||
self.mock_user.id = "test-user-id"
|
||||
self.mock_user.current_tenant_id = "test-tenant-id"
|
||||
|
||||
# Create mock session factory
|
||||
self.mock_session_factory = Mock(spec=sessionmaker)
|
||||
|
||||
# Create repository instance
|
||||
self.repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=self.mock_session_factory,
|
||||
user=self.mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
def test_save_with_duplicate_key_retries_with_new_uuid(self):
|
||||
"""Test that save retries with a new UUID v7 when encountering duplicate key error."""
|
||||
# Create a mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=None)
|
||||
self.mock_session_factory.return_value = mock_session
|
||||
|
||||
# Mock session.get to return None (no existing record)
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Create IntegrityError for duplicate key with proper psycopg2.errors.UniqueViolation
|
||||
mock_unique_violation = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError(
|
||||
"duplicate key value violates unique constraint",
|
||||
params=None,
|
||||
orig=mock_unique_violation,
|
||||
)
|
||||
|
||||
# First call to session.add raises IntegrityError, second succeeds
|
||||
mock_session.add.side_effect = [duplicate_error, None]
|
||||
mock_session.commit.side_effect = [None, None]
|
||||
|
||||
# Create test execution
|
||||
execution = WorkflowNodeExecution(
|
||||
id="original-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-execution-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.START,
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
original_id = execution.id
|
||||
|
||||
# Save should succeed after retry
|
||||
self.repository.save(execution)
|
||||
|
||||
# Verify that session.add was called twice (initial attempt + retry)
|
||||
assert mock_session.add.call_count == 2
|
||||
|
||||
# Verify that the ID was changed (new UUID v7 generated)
|
||||
assert execution.id != original_id
|
||||
|
||||
def test_save_with_existing_record_updates_instead_of_insert(self):
|
||||
"""Test that save updates existing record instead of inserting duplicate."""
|
||||
# Create a mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=None)
|
||||
self.mock_session_factory.return_value = mock_session
|
||||
|
||||
# Mock existing record
|
||||
mock_existing = MagicMock()
|
||||
mock_session.get.return_value = mock_existing
|
||||
mock_session.commit.return_value = None
|
||||
|
||||
# Create test execution
|
||||
execution = WorkflowNodeExecution(
|
||||
id="existing-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-execution-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.START,
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save should update existing record
|
||||
self.repository.save(execution)
|
||||
|
||||
# Verify that session.add was not called (update path)
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# Verify that session.commit was called
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_save_exceeds_max_retries_raises_error(self):
|
||||
"""Test that save raises error after exceeding max retries."""
|
||||
# Create a mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=None)
|
||||
self.mock_session_factory.return_value = mock_session
|
||||
|
||||
# Mock session.get to return None (no existing record)
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Create IntegrityError for duplicate key with proper psycopg2.errors.UniqueViolation
|
||||
mock_unique_violation = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError(
|
||||
"duplicate key value violates unique constraint",
|
||||
params=None,
|
||||
orig=mock_unique_violation,
|
||||
)
|
||||
|
||||
# All attempts fail with duplicate error
|
||||
mock_session.add.side_effect = duplicate_error
|
||||
|
||||
# Create test execution
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-execution-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.START,
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save should raise IntegrityError after max retries
|
||||
with pytest.raises(IntegrityError):
|
||||
self.repository.save(execution)
|
||||
|
||||
# Verify that session.add was called 3 times (max_retries)
|
||||
assert mock_session.add.call_count == 3
|
||||
|
||||
def test_save_non_duplicate_integrity_error_raises_immediately(self):
|
||||
"""Test that non-duplicate IntegrityErrors are raised immediately without retry."""
|
||||
# Create a mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=None)
|
||||
self.mock_session_factory.return_value = mock_session
|
||||
|
||||
# Mock session.get to return None (no existing record)
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Create IntegrityError for non-duplicate constraint
|
||||
other_error = IntegrityError(
|
||||
"null value in column violates not-null constraint",
|
||||
params=None,
|
||||
orig=None,
|
||||
)
|
||||
|
||||
# First call raises non-duplicate error
|
||||
mock_session.add.side_effect = other_error
|
||||
|
||||
# Create test execution
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-execution-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.START,
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save should raise error immediately
|
||||
with pytest.raises(IntegrityError):
|
||||
self.repository.save(execution)
|
||||
|
||||
# Verify that session.add was called only once (no retry)
|
||||
assert mock_session.add.call_count == 1
|
||||
@ -0,0 +1,217 @@
|
||||
"""
|
||||
Unit tests for WorkflowNodeExecution truncation functionality.
|
||||
|
||||
Tests the truncation and offloading logic for large inputs and outputs
|
||||
in the SQLAlchemyWorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from models import Account, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import ExecutionOffLoadType
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
|
||||
|
||||
@dataclass
|
||||
class TruncationTestCase:
|
||||
"""Test case data for truncation scenarios."""
|
||||
|
||||
name: str
|
||||
inputs: dict[str, Any] | None
|
||||
outputs: dict[str, Any] | None
|
||||
should_truncate_inputs: bool
|
||||
should_truncate_outputs: bool
|
||||
description: str
|
||||
|
||||
|
||||
def create_test_cases() -> list[TruncationTestCase]:
|
||||
"""Create test cases for different truncation scenarios."""
|
||||
# Create large data that will definitely exceed the threshold (10KB)
|
||||
large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)}
|
||||
small_data = {"data": "small"}
|
||||
|
||||
return [
|
||||
TruncationTestCase(
|
||||
name="small_data_no_truncation",
|
||||
inputs=small_data,
|
||||
outputs=small_data,
|
||||
should_truncate_inputs=False,
|
||||
should_truncate_outputs=False,
|
||||
description="Small data should not be truncated",
|
||||
),
|
||||
TruncationTestCase(
|
||||
name="large_inputs_truncation",
|
||||
inputs=large_data,
|
||||
outputs=small_data,
|
||||
should_truncate_inputs=True,
|
||||
should_truncate_outputs=False,
|
||||
description="Large inputs should be truncated",
|
||||
),
|
||||
TruncationTestCase(
|
||||
name="large_outputs_truncation",
|
||||
inputs=small_data,
|
||||
outputs=large_data,
|
||||
should_truncate_inputs=False,
|
||||
should_truncate_outputs=True,
|
||||
description="Large outputs should be truncated",
|
||||
),
|
||||
TruncationTestCase(
|
||||
name="large_both_truncation",
|
||||
inputs=large_data,
|
||||
outputs=large_data,
|
||||
should_truncate_inputs=True,
|
||||
should_truncate_outputs=True,
|
||||
description="Both large inputs and outputs should be truncated",
|
||||
),
|
||||
TruncationTestCase(
|
||||
name="none_inputs_outputs",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
should_truncate_inputs=False,
|
||||
should_truncate_outputs=False,
|
||||
description="None inputs and outputs should not be truncated",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_workflow_node_execution(
|
||||
execution_id: str = "test-execution-id",
|
||||
inputs: dict[str, Any] | None = None,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Factory function to create a WorkflowNodeExecution for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
node_execution_id="test-node-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-execution-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
def mock_user() -> Account:
|
||||
"""Create a mock Account user for testing."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "test-user-id"
|
||||
user.current_tenant_id = "test-tenant-id"
|
||||
return user
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
|
||||
"""Test class for truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""
|
||||
|
||||
def create_repository(self) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""Create a repository instance for testing."""
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=MagicMock(spec=Engine),
|
||||
user=mock_user(),
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
def test_to_domain_model_without_offload_data(self):
|
||||
"""Test _to_domain_model correctly handles models without offload data."""
|
||||
repo = self.create_repository()
|
||||
|
||||
# Create a mock database model without offload data
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "test-id"
|
||||
db_model.node_execution_id = "node-exec-id"
|
||||
db_model.workflow_id = "workflow-id"
|
||||
db_model.workflow_run_id = "run-id"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node-id"
|
||||
db_model.node_type = NodeType.LLM.value
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps({"value": "inputs"})
|
||||
db_model.process_data = json.dumps({"value": "process_data"})
|
||||
db_model.outputs = json.dumps({"value": "outputs"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.0
|
||||
db_model.execution_metadata = "{}"
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
db_model.offload_data = []
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
# Check that no truncated data was set
|
||||
assert domain_model.get_truncated_inputs() is None
|
||||
assert domain_model.get_truncated_outputs() is None
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionModelTruncatedProperties:
|
||||
"""Test the truncated properties on WorkflowNodeExecutionModel."""
|
||||
|
||||
def test_inputs_truncated_with_offload_data(self):
|
||||
"""Test inputs_truncated property when offload data exists."""
|
||||
model = WorkflowNodeExecutionModel()
|
||||
offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
model.offload_data = [offload]
|
||||
|
||||
assert model.inputs_truncated is True
|
||||
assert model.process_data_truncated is False
|
||||
assert model.outputs_truncated is False
|
||||
|
||||
def test_outputs_truncated_with_offload_data(self):
|
||||
"""Test outputs_truncated property when offload data exists."""
|
||||
model = WorkflowNodeExecutionModel()
|
||||
|
||||
# Mock offload data with outputs file
|
||||
offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
model.offload_data = [offload]
|
||||
|
||||
assert model.inputs_truncated is False
|
||||
assert model.process_data_truncated is False
|
||||
assert model.outputs_truncated is True
|
||||
|
||||
def test_process_data_truncated_with_offload_data(self):
|
||||
model = WorkflowNodeExecutionModel()
|
||||
offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
|
||||
model.offload_data = [offload]
|
||||
assert model.process_data_truncated is True
|
||||
assert model.inputs_truncated is False
|
||||
assert model.outputs_truncated is False
|
||||
|
||||
def test_truncated_properties_without_offload_data(self):
|
||||
"""Test truncated properties when no offload data exists."""
|
||||
model = WorkflowNodeExecutionModel()
|
||||
model.offload_data = []
|
||||
|
||||
assert model.inputs_truncated is False
|
||||
assert model.outputs_truncated is False
|
||||
assert model.process_data_truncated is False
|
||||
|
||||
def test_truncated_properties_without_offload_attribute(self):
|
||||
"""Test truncated properties when offload_data attribute doesn't exist."""
|
||||
model = WorkflowNodeExecutionModel()
|
||||
# Don't set offload_data attribute at all
|
||||
|
||||
assert model.inputs_truncated is False
|
||||
assert model.outputs_truncated is False
|
||||
assert model.process_data_truncated is False
|
||||
1
api/tests/unit_tests/core/schemas/__init__.py
Normal file
1
api/tests/unit_tests/core/schemas/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Core schemas unit tests
|
||||
769
api/tests/unit_tests/core/schemas/test_resolver.py
Normal file
769
api/tests/unit_tests/core/schemas/test_resolver.py
Normal file
@ -0,0 +1,769 @@
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.schemas import resolve_dify_schema_refs
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
from core.schemas.resolver import (
|
||||
MaxDepthExceededError,
|
||||
SchemaResolver,
|
||||
_has_dify_refs,
|
||||
_has_dify_refs_hybrid,
|
||||
_has_dify_refs_recursive,
|
||||
_is_dify_schema_ref,
|
||||
_remove_metadata_fields,
|
||||
parse_dify_schema_uri,
|
||||
)
|
||||
|
||||
|
||||
class TestSchemaResolver:
|
||||
"""Test cases for schema reference resolution"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup method to initialize test resources"""
|
||||
self.registry = SchemaRegistry.default_registry()
|
||||
# Clear cache before each test
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Cleanup after each test"""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
def test_simple_ref_resolution(self):
|
||||
"""Test resolving a simple $ref to a complete schema"""
|
||||
schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
|
||||
resolved = resolve_dify_schema_refs(schema_with_ref)
|
||||
|
||||
# Should be resolved to the actual qa_structure schema
|
||||
assert resolved["type"] == "object"
|
||||
assert resolved["title"] == "Q&A Structure"
|
||||
assert "qa_chunks" in resolved["properties"]
|
||||
assert resolved["properties"]["qa_chunks"]["type"] == "array"
|
||||
|
||||
# Metadata fields should be removed
|
||||
assert "$id" not in resolved
|
||||
assert "$schema" not in resolved
|
||||
assert "version" not in resolved
|
||||
|
||||
def test_nested_object_with_refs(self):
|
||||
"""Test resolving $refs within nested object structures"""
|
||||
nested_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
"metadata": {"type": "string", "description": "Additional metadata"},
|
||||
},
|
||||
}
|
||||
|
||||
resolved = resolve_dify_schema_refs(nested_schema)
|
||||
|
||||
# Original structure should be preserved
|
||||
assert resolved["type"] == "object"
|
||||
assert "metadata" in resolved["properties"]
|
||||
assert resolved["properties"]["metadata"]["type"] == "string"
|
||||
|
||||
# $ref should be resolved
|
||||
file_schema = resolved["properties"]["file_data"]
|
||||
assert file_schema["type"] == "object"
|
||||
assert file_schema["title"] == "File"
|
||||
assert "name" in file_schema["properties"]
|
||||
|
||||
# Metadata fields should be removed from resolved schema
|
||||
assert "$id" not in file_schema
|
||||
assert "$schema" not in file_schema
|
||||
assert "version" not in file_schema
|
||||
|
||||
def test_array_items_ref_resolution(self):
|
||||
"""Test resolving $refs in array items"""
|
||||
array_schema = {
|
||||
"type": "array",
|
||||
"items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"},
|
||||
"description": "Array of general structures",
|
||||
}
|
||||
|
||||
resolved = resolve_dify_schema_refs(array_schema)
|
||||
|
||||
# Array structure should be preserved
|
||||
assert resolved["type"] == "array"
|
||||
assert resolved["description"] == "Array of general structures"
|
||||
|
||||
# Items $ref should be resolved
|
||||
items_schema = resolved["items"]
|
||||
assert items_schema["type"] == "array"
|
||||
assert items_schema["title"] == "General Structure"
|
||||
|
||||
def test_non_dify_ref_unchanged(self):
|
||||
"""Test that non-Dify $refs are left unchanged"""
|
||||
external_ref_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"external_data": {"$ref": "https://example.com/external-schema.json"},
|
||||
"dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
},
|
||||
}
|
||||
|
||||
resolved = resolve_dify_schema_refs(external_ref_schema)
|
||||
|
||||
# External $ref should remain unchanged
|
||||
assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json"
|
||||
|
||||
# Dify $ref should be resolved
|
||||
assert resolved["properties"]["dify_data"]["type"] == "object"
|
||||
assert resolved["properties"]["dify_data"]["title"] == "File"
|
||||
|
||||
def test_no_refs_schema_unchanged(self):
|
||||
"""Test that schemas without $refs are returned unchanged"""
|
||||
simple_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Name field"},
|
||||
"items": {"type": "array", "items": {"type": "number"}},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
resolved = resolve_dify_schema_refs(simple_schema)
|
||||
|
||||
# Should be identical to input
|
||||
assert resolved == simple_schema
|
||||
assert resolved["type"] == "object"
|
||||
assert resolved["properties"]["name"]["type"] == "string"
|
||||
assert resolved["properties"]["items"]["items"]["type"] == "number"
|
||||
assert resolved["required"] == ["name"]
|
||||
|
||||
def test_recursion_depth_protection(self):
|
||||
"""Test that excessive recursion depth is prevented"""
|
||||
# Create a moderately nested structure
|
||||
deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
|
||||
# Wrap it in fewer layers to make the test more reasonable
|
||||
for _ in range(2):
|
||||
deep_schema = {"type": "object", "properties": {"nested": deep_schema}}
|
||||
|
||||
# Should handle normal cases fine with reasonable depth
|
||||
resolved = resolve_dify_schema_refs(deep_schema, max_depth=25)
|
||||
assert resolved is not None
|
||||
assert resolved["type"] == "object"
|
||||
|
||||
# Should raise error with very low max_depth
|
||||
with pytest.raises(MaxDepthExceededError) as exc_info:
|
||||
resolve_dify_schema_refs(deep_schema, max_depth=5)
|
||||
assert exc_info.value.max_depth == 5
|
||||
|
||||
def test_circular_reference_detection(self):
|
||||
"""Test that circular references are detected and handled"""
|
||||
# Mock registry with circular reference
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_schema.side_effect = lambda uri: {
|
||||
"$ref": "https://dify.ai/schemas/v1/circular.json",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"}
|
||||
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
|
||||
|
||||
# Should mark circular reference
|
||||
assert "$circular_ref" in resolved
|
||||
|
||||
def test_schema_not_found_handling(self):
|
||||
"""Test handling of missing schemas"""
|
||||
# Mock registry that returns None for unknown schemas
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_schema.return_value = None
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"}
|
||||
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
|
||||
|
||||
# Should keep the original $ref when schema not found
|
||||
assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json"
|
||||
|
||||
def test_primitive_types_unchanged(self):
|
||||
"""Test that primitive types are returned unchanged"""
|
||||
assert resolve_dify_schema_refs("string") == "string"
|
||||
assert resolve_dify_schema_refs(123) == 123
|
||||
assert resolve_dify_schema_refs(True) is True
|
||||
assert resolve_dify_schema_refs(None) is None
|
||||
assert resolve_dify_schema_refs(3.14) == 3.14
|
||||
|
||||
def test_cache_functionality(self):
|
||||
"""Test that caching works correctly"""
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
|
||||
# First resolution should fetch from registry
|
||||
resolved1 = resolve_dify_schema_refs(schema)
|
||||
|
||||
# Mock the registry to return different data
|
||||
with patch.object(self.registry, "get_schema") as mock_get:
|
||||
mock_get.return_value = {"type": "different"}
|
||||
|
||||
# Second resolution should use cache
|
||||
resolved2 = resolve_dify_schema_refs(schema)
|
||||
|
||||
# Should be the same as first resolution (from cache)
|
||||
assert resolved1 == resolved2
|
||||
# Mock should not have been called
|
||||
mock_get.assert_not_called()
|
||||
|
||||
# Clear cache and try again
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
# Now it should fetch again
|
||||
resolved3 = resolve_dify_schema_refs(schema)
|
||||
assert resolved3 == resolved1
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Test that the resolver is thread-safe"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)},
|
||||
}
|
||||
|
||||
results = []
|
||||
|
||||
def resolve_in_thread():
|
||||
try:
|
||||
result = resolve_dify_schema_refs(schema)
|
||||
results.append(result)
|
||||
return True
|
||||
except Exception as e:
|
||||
results.append(e)
|
||||
return False
|
||||
|
||||
# Run multiple threads concurrently
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(resolve_in_thread) for _ in range(20)]
|
||||
success = all(f.result() for f in futures)
|
||||
|
||||
assert success
|
||||
# All results should be the same
|
||||
first_result = results[0]
|
||||
assert all(r == first_result for r in results if not isinstance(r, Exception))
|
||||
|
||||
def test_mixed_nested_structures(self):
|
||||
"""Test resolving refs in complex mixed structures"""
|
||||
complex_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}},
|
||||
"nested": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"qa": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resolved = resolve_dify_schema_refs(complex_schema, max_depth=20)
|
||||
|
||||
# Check structure is preserved
|
||||
assert resolved["type"] == "object"
|
||||
assert "files" in resolved["properties"]
|
||||
assert "nested" in resolved["properties"]
|
||||
|
||||
# Check refs are resolved
|
||||
assert resolved["properties"]["files"]["items"]["type"] == "object"
|
||||
assert resolved["properties"]["files"]["items"]["title"] == "File"
|
||||
assert resolved["properties"]["nested"]["properties"]["qa"]["type"] == "object"
|
||||
assert resolved["properties"]["nested"]["properties"]["qa"]["title"] == "Q&A Structure"
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions"""
|
||||
|
||||
def test_is_dify_schema_ref(self):
|
||||
"""Test _is_dify_schema_ref function"""
|
||||
# Valid Dify refs
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json")
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json")
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json")
|
||||
|
||||
# Invalid refs
|
||||
assert not _is_dify_schema_ref("https://example.com/schema.json")
|
||||
assert not _is_dify_schema_ref("https://dify.ai/other/path.json")
|
||||
assert not _is_dify_schema_ref("not a uri")
|
||||
assert not _is_dify_schema_ref("")
|
||||
assert not _is_dify_schema_ref(None)
|
||||
assert not _is_dify_schema_ref(123)
|
||||
assert not _is_dify_schema_ref(["list"])
|
||||
|
||||
def test_has_dify_refs(self):
|
||||
"""Test _has_dify_refs function"""
|
||||
# Schemas with Dify refs
|
||||
assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"})
|
||||
assert _has_dify_refs(
|
||||
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}
|
||||
)
|
||||
assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}])
|
||||
assert _has_dify_refs(
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Schemas without Dify refs
|
||||
assert not _has_dify_refs({"type": "string"})
|
||||
assert not _has_dify_refs(
|
||||
{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}
|
||||
)
|
||||
assert not _has_dify_refs(
|
||||
[{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}]
|
||||
)
|
||||
|
||||
# Schemas with non-Dify refs (should return False)
|
||||
assert not _has_dify_refs({"$ref": "https://example.com/schema.json"})
|
||||
assert not _has_dify_refs(
|
||||
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}
|
||||
)
|
||||
|
||||
# Primitive types
|
||||
assert not _has_dify_refs("string")
|
||||
assert not _has_dify_refs(123)
|
||||
assert not _has_dify_refs(True)
|
||||
assert not _has_dify_refs(None)
|
||||
|
||||
def test_has_dify_refs_hybrid_vs_recursive(self):
|
||||
"""Test that hybrid and recursive detection give same results"""
|
||||
test_schemas = [
|
||||
# No refs
|
||||
{"type": "string"},
|
||||
{"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
[{"type": "string"}, {"type": "number"}],
|
||||
# With Dify refs
|
||||
{"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}},
|
||||
[{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}],
|
||||
# With non-Dify refs
|
||||
{"$ref": "https://example.com/schema.json"},
|
||||
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}},
|
||||
# Complex nested
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
# Edge cases
|
||||
{"description": "This mentions $ref but is not a reference"},
|
||||
{"$ref": "not-a-url"},
|
||||
# Primitive types
|
||||
"string",
|
||||
123,
|
||||
True,
|
||||
None,
|
||||
[],
|
||||
]
|
||||
|
||||
for schema in test_schemas:
|
||||
hybrid_result = _has_dify_refs_hybrid(schema)
|
||||
recursive_result = _has_dify_refs_recursive(schema)
|
||||
|
||||
assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}"
|
||||
|
||||
def test_parse_dify_schema_uri(self):
|
||||
"""Test parse_dify_schema_uri function"""
|
||||
# Valid URIs
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file")
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name")
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file")
|
||||
|
||||
# Invalid URIs
|
||||
assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "")
|
||||
assert parse_dify_schema_uri("invalid") == ("", "")
|
||||
assert parse_dify_schema_uri("") == ("", "")
|
||||
|
||||
def test_remove_metadata_fields(self):
|
||||
"""Test _remove_metadata_fields function"""
|
||||
schema = {
|
||||
"$id": "should be removed",
|
||||
"$schema": "should be removed",
|
||||
"version": "should be removed",
|
||||
"type": "object",
|
||||
"title": "should remain",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
cleaned = _remove_metadata_fields(schema)
|
||||
|
||||
assert "$id" not in cleaned
|
||||
assert "$schema" not in cleaned
|
||||
assert "version" not in cleaned
|
||||
assert cleaned["type"] == "object"
|
||||
assert cleaned["title"] == "should remain"
|
||||
assert "properties" in cleaned
|
||||
|
||||
# Original should be unchanged
|
||||
assert "$id" in schema
|
||||
|
||||
|
||||
class TestSchemaResolverClass:
|
||||
"""Test SchemaResolver class specifically"""
|
||||
|
||||
def test_resolver_initialization(self):
|
||||
"""Test resolver initialization"""
|
||||
# Default initialization
|
||||
resolver = SchemaResolver()
|
||||
assert resolver.max_depth == 10
|
||||
assert resolver.registry is not None
|
||||
|
||||
# Custom initialization
|
||||
custom_registry = MagicMock()
|
||||
resolver = SchemaResolver(registry=custom_registry, max_depth=5)
|
||||
assert resolver.max_depth == 5
|
||||
assert resolver.registry is custom_registry
|
||||
|
||||
def test_cache_sharing(self):
|
||||
"""Test that cache is shared between resolver instances"""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
|
||||
# First resolver populates cache
|
||||
resolver1 = SchemaResolver()
|
||||
result1 = resolver1.resolve(schema)
|
||||
|
||||
# Second resolver should use the same cache
|
||||
resolver2 = SchemaResolver()
|
||||
with patch.object(resolver2.registry, "get_schema") as mock_get:
|
||||
result2 = resolver2.resolve(schema)
|
||||
# Should not call registry since it's in cache
|
||||
mock_get.assert_not_called()
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_resolver_with_list_schema(self):
|
||||
"""Test resolver with list as root schema"""
|
||||
list_schema = [
|
||||
{"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
{"type": "string"},
|
||||
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
|
||||
]
|
||||
|
||||
resolver = SchemaResolver()
|
||||
resolved = resolver.resolve(list_schema)
|
||||
|
||||
assert isinstance(resolved, list)
|
||||
assert len(resolved) == 3
|
||||
assert resolved[0]["type"] == "object"
|
||||
assert resolved[0]["title"] == "File"
|
||||
assert resolved[1] == {"type": "string"}
|
||||
assert resolved[2]["type"] == "object"
|
||||
assert resolved[2]["title"] == "Q&A Structure"
|
||||
|
||||
def test_cache_performance(self):
|
||||
"""Test that caching improves performance"""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
# Create a schema with many references to the same schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
for i in range(50) # Reduced to avoid depth issues
|
||||
},
|
||||
}
|
||||
|
||||
# First run (no cache) - run multiple times to warm up
|
||||
results1 = []
|
||||
for _ in range(3):
|
||||
SchemaResolver.clear_cache()
|
||||
start = time.perf_counter()
|
||||
result1 = resolve_dify_schema_refs(schema)
|
||||
time_no_cache = time.perf_counter() - start
|
||||
results1.append(time_no_cache)
|
||||
|
||||
avg_time_no_cache = sum(results1) / len(results1)
|
||||
|
||||
# Second run (with cache) - run multiple times
|
||||
results2 = []
|
||||
for _ in range(3):
|
||||
start = time.perf_counter()
|
||||
result2 = resolve_dify_schema_refs(schema)
|
||||
time_with_cache = time.perf_counter() - start
|
||||
results2.append(time_with_cache)
|
||||
|
||||
avg_time_with_cache = sum(results2) / len(results2)
|
||||
|
||||
# Cache should make it faster (more lenient check)
|
||||
assert result1 == result2
|
||||
# Cache should provide some performance benefit (allow for measurement variance)
|
||||
# We expect cache to be faster, but allow for small timing variations
|
||||
performance_ratio = avg_time_with_cache / avg_time_no_cache if avg_time_no_cache > 0 else 1.0
|
||||
assert performance_ratio <= 2.0, f"Cache performance degraded too much: {performance_ratio}"
|
||||
|
||||
def test_fast_path_performance_no_refs(self):
|
||||
"""Test that schemas without $refs use fast path and avoid deep copying"""
|
||||
# Create a moderately complex schema without any $refs (typical plugin output_schema)
|
||||
no_refs_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"property_{i}": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
"items": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
for i in range(50)
|
||||
},
|
||||
}
|
||||
|
||||
# Measure fast path (no refs) performance
|
||||
fast_times = []
|
||||
for _ in range(10):
|
||||
start = time.perf_counter()
|
||||
result_fast = resolve_dify_schema_refs(no_refs_schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
fast_times.append(elapsed)
|
||||
|
||||
avg_fast_time = sum(fast_times) / len(fast_times)
|
||||
|
||||
# Most importantly: result should be identical to input (no copying)
|
||||
assert result_fast is no_refs_schema
|
||||
|
||||
# Create schema with $refs for comparison (same structure size)
|
||||
with_refs_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
for i in range(20) # Fewer to avoid depth issues but still comparable
|
||||
},
|
||||
}
|
||||
|
||||
# Measure slow path (with refs) performance
|
||||
SchemaResolver.clear_cache()
|
||||
slow_times = []
|
||||
for _ in range(10):
|
||||
SchemaResolver.clear_cache()
|
||||
start = time.perf_counter()
|
||||
result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50)
|
||||
elapsed = time.perf_counter() - start
|
||||
slow_times.append(elapsed)
|
||||
|
||||
avg_slow_time = sum(slow_times) / len(slow_times)
|
||||
|
||||
# The key benefit: fast path should be reasonably fast (main goal is no deep copy)
|
||||
# and definitely avoid the expensive BFS resolution
|
||||
# Even if detection has some overhead, it should still be faster for typical cases
|
||||
print(f"Fast path (no refs): {avg_fast_time:.6f}s")
|
||||
print(f"Slow path (with refs): {avg_slow_time:.6f}s")
|
||||
|
||||
# More lenient check: fast path should be at least somewhat competitive
|
||||
# The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster
|
||||
assert avg_fast_time < avg_slow_time * 2 # Should not be more than 2x slower
|
||||
|
||||
def test_batch_processing_performance(self):
|
||||
"""Test performance improvement for batch processing of schemas without refs"""
|
||||
# Simulate the plugin tool scenario: many schemas, most without refs
|
||||
schemas_without_refs = [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)},
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
# Test batch processing performance
|
||||
start = time.perf_counter()
|
||||
results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs]
|
||||
batch_time = time.perf_counter() - start
|
||||
|
||||
# Verify all results are identical to inputs (fast path used)
|
||||
for original, result in zip(schemas_without_refs, results):
|
||||
assert result is original
|
||||
|
||||
# Should be very fast - each schema should take < 0.001 seconds on average
|
||||
avg_time_per_schema = batch_time / len(schemas_without_refs)
|
||||
assert avg_time_per_schema < 0.001
|
||||
|
||||
def test_has_dify_refs_performance(self):
|
||||
"""Test that _has_dify_refs is fast for large schemas without refs"""
|
||||
# Create a very large schema without refs
|
||||
large_schema = {"type": "object", "properties": {}}
|
||||
|
||||
# Add many nested properties
|
||||
current = large_schema
|
||||
for i in range(100):
|
||||
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
|
||||
current = current["properties"][f"level_{i}"]
|
||||
|
||||
# _has_dify_refs should be fast even for large schemas
|
||||
times = []
|
||||
for _ in range(50):
|
||||
start = time.perf_counter()
|
||||
has_refs = _has_dify_refs(large_schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
|
||||
# Should be False and fast
|
||||
assert not has_refs
|
||||
assert avg_time < 0.01 # Should complete in less than 10ms
|
||||
|
||||
def test_hybrid_vs_recursive_performance(self):
|
||||
"""Test performance comparison between hybrid and recursive detection"""
|
||||
# Create test schemas of different types and sizes
|
||||
test_cases = [
|
||||
# Case 1: Small schema without refs (most common case)
|
||||
{
|
||||
"name": "small_no_refs",
|
||||
"schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}},
|
||||
"expected": False,
|
||||
},
|
||||
# Case 2: Medium schema without refs
|
||||
{
|
||||
"name": "medium_no_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"field_{i}": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
"items": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
for i in range(20)
|
||||
},
|
||||
},
|
||||
"expected": False,
|
||||
},
|
||||
# Case 3: Large schema without refs
|
||||
{"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False},
|
||||
# Case 4: Schema with Dify refs
|
||||
{
|
||||
"name": "with_dify_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
"data": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"expected": True,
|
||||
},
|
||||
# Case 5: Schema with non-Dify refs
|
||||
{
|
||||
"name": "with_external_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}},
|
||||
},
|
||||
"expected": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Add deep nesting to large schema
|
||||
current = test_cases[2]["schema"]
|
||||
for i in range(50):
|
||||
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
|
||||
current = current["properties"][f"level_{i}"]
|
||||
|
||||
# Performance comparison
|
||||
for test_case in test_cases:
|
||||
schema = test_case["schema"]
|
||||
expected = test_case["expected"]
|
||||
name = test_case["name"]
|
||||
|
||||
# Test correctness first
|
||||
assert _has_dify_refs_hybrid(schema) == expected
|
||||
assert _has_dify_refs_recursive(schema) == expected
|
||||
|
||||
# Measure hybrid performance
|
||||
hybrid_times = []
|
||||
for _ in range(10):
|
||||
start = time.perf_counter()
|
||||
result_hybrid = _has_dify_refs_hybrid(schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
hybrid_times.append(elapsed)
|
||||
|
||||
# Measure recursive performance
|
||||
recursive_times = []
|
||||
for _ in range(10):
|
||||
start = time.perf_counter()
|
||||
result_recursive = _has_dify_refs_recursive(schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
recursive_times.append(elapsed)
|
||||
|
||||
avg_hybrid = sum(hybrid_times) / len(hybrid_times)
|
||||
avg_recursive = sum(recursive_times) / len(recursive_times)
|
||||
|
||||
print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s")
|
||||
|
||||
# Results should be identical
|
||||
assert result_hybrid == result_recursive == expected
|
||||
|
||||
# For schemas without refs, hybrid should be competitive or better
|
||||
if not expected: # No refs case
|
||||
# Hybrid might be slightly slower due to JSON serialization overhead,
|
||||
# but should not be dramatically worse
|
||||
assert avg_hybrid < avg_recursive * 5 # At most 5x slower
|
||||
|
||||
def test_string_matching_edge_cases(self):
|
||||
"""Test edge cases for string-based detection"""
|
||||
# Case 1: False positive potential - $ref in description
|
||||
schema_false_positive = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"}
|
||||
},
|
||||
}
|
||||
|
||||
# Both methods should return False
|
||||
assert not _has_dify_refs_hybrid(schema_false_positive)
|
||||
assert not _has_dify_refs_recursive(schema_false_positive)
|
||||
|
||||
# Case 2: Complex URL patterns
|
||||
complex_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"},
|
||||
"actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Both methods should return True (due to actual_ref)
|
||||
assert _has_dify_refs_hybrid(complex_schema)
|
||||
assert _has_dify_refs_recursive(complex_schema)
|
||||
|
||||
# Case 3: Non-JSON serializable objects (should fall back to recursive)
|
||||
import datetime
|
||||
|
||||
non_serializable = {
|
||||
"type": "object",
|
||||
"timestamp": datetime.datetime.now(),
|
||||
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
}
|
||||
|
||||
# Hybrid should fall back to recursive and still work
|
||||
assert _has_dify_refs_hybrid(non_serializable)
|
||||
assert _has_dify_refs_recursive(non_serializable)
|
||||
@ -15,7 +15,7 @@ class FakeResponse:
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.content = content
|
||||
self.text = text if text else content.decode("utf-8", errors="ignore")
|
||||
self.text = text or content.decode("utf-8", errors="ignore")
|
||||
|
||||
|
||||
# ---------------------------
|
||||
|
||||
@ -17,7 +17,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
output_schema=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
|
||||
@ -37,7 +37,7 @@ from core.variables.variables import (
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test basic segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test number segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad:
|
||||
"""Test variable serialization compatibility"""
|
||||
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
restored = _Variables.model_validate_json(json)
|
||||
assert restored == model
|
||||
|
||||
|
||||
@ -0,0 +1,97 @@
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
|
||||
def test_property_getters_and_setters(self):
|
||||
# FIXME(-LAN-): Mock VariablePool if needed
|
||||
variable_pool = VariablePool()
|
||||
start_time = time()
|
||||
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time)
|
||||
|
||||
# Test variable_pool property (read-only)
|
||||
assert state.variable_pool == variable_pool
|
||||
|
||||
# Test start_at property
|
||||
assert state.start_at == start_time
|
||||
new_time = time() + 100
|
||||
state.start_at = new_time
|
||||
assert state.start_at == new_time
|
||||
|
||||
# Test total_tokens property
|
||||
assert state.total_tokens == 0
|
||||
state.total_tokens = 100
|
||||
assert state.total_tokens == 100
|
||||
|
||||
# Test node_run_steps property
|
||||
assert state.node_run_steps == 0
|
||||
state.node_run_steps = 5
|
||||
assert state.node_run_steps == 5
|
||||
|
||||
def test_outputs_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting outputs returns a copy
|
||||
outputs1 = state.outputs
|
||||
outputs2 = state.outputs
|
||||
assert outputs1 == outputs2
|
||||
assert outputs1 is not outputs2 # Different objects
|
||||
|
||||
# Test that modifying retrieved outputs doesn't affect internal state
|
||||
outputs = state.outputs
|
||||
outputs["test"] = "value"
|
||||
assert "test" not in state.outputs
|
||||
|
||||
# Test set_output method
|
||||
state.set_output("key1", "value1")
|
||||
assert state.get_output("key1") == "value1"
|
||||
|
||||
# Test update_outputs method
|
||||
state.update_outputs({"key2": "value2", "key3": "value3"})
|
||||
assert state.get_output("key2") == "value2"
|
||||
assert state.get_output("key3") == "value3"
|
||||
|
||||
def test_llm_usage_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting llm_usage returns a copy
|
||||
usage1 = state.llm_usage
|
||||
usage2 = state.llm_usage
|
||||
assert usage1 is not usage2 # Different objects
|
||||
|
||||
def test_type_validation(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test total_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.total_tokens = -1
|
||||
|
||||
# Test node_run_steps validation
|
||||
with pytest.raises(ValueError):
|
||||
state.node_run_steps = -1
|
||||
|
||||
def test_helper_methods(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test increment_node_run_steps
|
||||
initial_steps = state.node_run_steps
|
||||
state.increment_node_run_steps()
|
||||
assert state.node_run_steps == initial_steps + 1
|
||||
|
||||
# Test add_tokens
|
||||
initial_tokens = state.total_tokens
|
||||
state.add_tokens(50)
|
||||
assert state.total_tokens == initial_tokens + 50
|
||||
|
||||
# Test add_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.add_tokens(-1)
|
||||
87
api/tests/unit_tests/core/workflow/entities/test_template.py
Normal file
87
api/tests/unit_tests/core/workflow/entities/test_template.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Tests for template module."""
|
||||
|
||||
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||
|
||||
|
||||
class TestTemplate:
|
||||
"""Test Template class functionality."""
|
||||
|
||||
def test_from_answer_template_simple(self):
|
||||
"""Test parsing a simple answer template."""
|
||||
template_str = "Hello, {{#node1.name#}}!"
|
||||
template = Template.from_answer_template(template_str)
|
||||
|
||||
assert len(template.segments) == 3
|
||||
assert isinstance(template.segments[0], TextSegment)
|
||||
assert template.segments[0].text == "Hello, "
|
||||
assert isinstance(template.segments[1], VariableSegment)
|
||||
assert template.segments[1].selector == ["node1", "name"]
|
||||
assert isinstance(template.segments[2], TextSegment)
|
||||
assert template.segments[2].text == "!"
|
||||
|
||||
def test_from_answer_template_multiple_vars(self):
|
||||
"""Test parsing an answer template with multiple variables."""
|
||||
template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}."
|
||||
template = Template.from_answer_template(template_str)
|
||||
|
||||
assert len(template.segments) == 5
|
||||
assert isinstance(template.segments[0], TextSegment)
|
||||
assert template.segments[0].text == "Hello "
|
||||
assert isinstance(template.segments[1], VariableSegment)
|
||||
assert template.segments[1].selector == ["node1", "name"]
|
||||
assert isinstance(template.segments[2], TextSegment)
|
||||
assert template.segments[2].text == ", your age is "
|
||||
assert isinstance(template.segments[3], VariableSegment)
|
||||
assert template.segments[3].selector == ["node2", "age"]
|
||||
assert isinstance(template.segments[4], TextSegment)
|
||||
assert template.segments[4].text == "."
|
||||
|
||||
def test_from_answer_template_no_vars(self):
|
||||
"""Test parsing an answer template with no variables."""
|
||||
template_str = "Hello, world!"
|
||||
template = Template.from_answer_template(template_str)
|
||||
|
||||
assert len(template.segments) == 1
|
||||
assert isinstance(template.segments[0], TextSegment)
|
||||
assert template.segments[0].text == "Hello, world!"
|
||||
|
||||
def test_from_end_outputs_single(self):
|
||||
"""Test creating template from End node outputs with single variable."""
|
||||
outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}]
|
||||
template = Template.from_end_outputs(outputs_config)
|
||||
|
||||
assert len(template.segments) == 1
|
||||
assert isinstance(template.segments[0], VariableSegment)
|
||||
assert template.segments[0].selector == ["node1", "text"]
|
||||
|
||||
def test_from_end_outputs_multiple(self):
|
||||
"""Test creating template from End node outputs with multiple variables."""
|
||||
outputs_config = [
|
||||
{"variable": "text", "value_selector": ["node1", "text"]},
|
||||
{"variable": "result", "value_selector": ["node2", "result"]},
|
||||
]
|
||||
template = Template.from_end_outputs(outputs_config)
|
||||
|
||||
assert len(template.segments) == 3
|
||||
assert isinstance(template.segments[0], VariableSegment)
|
||||
assert template.segments[0].selector == ["node1", "text"]
|
||||
assert template.segments[0].variable_name == "text"
|
||||
assert isinstance(template.segments[1], TextSegment)
|
||||
assert template.segments[1].text == "\n"
|
||||
assert isinstance(template.segments[2], VariableSegment)
|
||||
assert template.segments[2].selector == ["node2", "result"]
|
||||
assert template.segments[2].variable_name == "result"
|
||||
|
||||
def test_from_end_outputs_empty(self):
|
||||
"""Test creating template from empty End node outputs."""
|
||||
outputs_config = []
|
||||
template = Template.from_end_outputs(outputs_config)
|
||||
|
||||
assert len(template.segments) == 0
|
||||
|
||||
def test_template_str_representation(self):
|
||||
"""Test string representation of template."""
|
||||
template_str = "Hello, {{#node1.name#}}!"
|
||||
template = Template.from_answer_template(template_str)
|
||||
|
||||
assert str(template) == template_str
|
||||
@ -0,0 +1,225 @@
|
||||
"""
|
||||
Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionProcessDataTruncation:
|
||||
"""Test process_data truncation functionality in WorkflowNodeExecution domain model."""
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution instance for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def test_initial_process_data_truncated_state(self):
|
||||
"""Test that process_data_truncated returns False initially."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
assert execution.process_data_truncated is False
|
||||
assert execution.get_truncated_process_data() is None
|
||||
|
||||
def test_set_and_get_truncated_process_data(self):
|
||||
"""Test setting and getting truncated process_data."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
test_truncated_data = {"truncated": True, "key": "value"}
|
||||
|
||||
execution.set_truncated_process_data(test_truncated_data)
|
||||
|
||||
assert execution.process_data_truncated is True
|
||||
assert execution.get_truncated_process_data() == test_truncated_data
|
||||
|
||||
def test_set_truncated_process_data_to_none(self):
|
||||
"""Test setting truncated process_data to None."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
# First set some data
|
||||
execution.set_truncated_process_data({"key": "value"})
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
# Then set to None
|
||||
execution.set_truncated_process_data(None)
|
||||
assert execution.process_data_truncated is False
|
||||
assert execution.get_truncated_process_data() is None
|
||||
|
||||
def test_get_response_process_data_with_no_truncation(self):
|
||||
"""Test get_response_process_data when no truncation is set."""
|
||||
original_data = {"original": True, "data": "value"}
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
assert response_data == original_data
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_get_response_process_data_with_truncation(self):
|
||||
"""Test get_response_process_data when truncation is set."""
|
||||
original_data = {"original": True, "large_data": "x" * 10000}
|
||||
truncated_data = {"original": True, "large_data": "[TRUNCATED]"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
execution.set_truncated_process_data(truncated_data)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
# Should return truncated data, not original
|
||||
assert response_data == truncated_data
|
||||
assert response_data != original_data
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
def test_get_response_process_data_with_none_process_data(self):
|
||||
"""Test get_response_process_data when process_data is None."""
|
||||
execution = self.create_workflow_node_execution(process_data=None)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
assert response_data is None
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_consistency_with_inputs_outputs_pattern(self):
|
||||
"""Test that process_data truncation follows the same pattern as inputs/outputs."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
# Test that all truncation methods exist and behave consistently
|
||||
test_data = {"test": "data"}
|
||||
|
||||
# Test inputs truncation
|
||||
execution.set_truncated_inputs(test_data)
|
||||
assert execution.inputs_truncated is True
|
||||
assert execution.get_truncated_inputs() == test_data
|
||||
|
||||
# Test outputs truncation
|
||||
execution.set_truncated_outputs(test_data)
|
||||
assert execution.outputs_truncated is True
|
||||
assert execution.get_truncated_outputs() == test_data
|
||||
|
||||
# Test process_data truncation
|
||||
execution.set_truncated_process_data(test_data)
|
||||
assert execution.process_data_truncated is True
|
||||
assert execution.get_truncated_process_data() == test_data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_data",
|
||||
[
|
||||
{"simple": "value"},
|
||||
{"nested": {"key": "value"}},
|
||||
{"list": [1, 2, 3]},
|
||||
{"mixed": {"string": "value", "number": 42, "list": [1, 2]}},
|
||||
{}, # empty dict
|
||||
],
|
||||
)
|
||||
def test_truncated_process_data_with_various_data_types(self, test_data):
|
||||
"""Test that truncated process_data works with various data types."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
execution.set_truncated_process_data(test_data)
|
||||
|
||||
assert execution.process_data_truncated is True
|
||||
assert execution.get_truncated_process_data() == test_data
|
||||
assert execution.get_response_process_data() == test_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessDataScenario:
|
||||
"""Test scenario data for process_data functionality."""
|
||||
|
||||
name: str
|
||||
original_data: dict[str, Any] | None
|
||||
truncated_data: dict[str, Any] | None
|
||||
expected_truncated_flag: bool
|
||||
expected_response_data: dict[str, Any] | None
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionProcessDataScenarios:
|
||||
"""Test various scenarios for process_data handling."""
|
||||
|
||||
def get_process_data_scenarios(self) -> list[ProcessDataScenario]:
|
||||
"""Create test scenarios for process_data functionality."""
|
||||
return [
|
||||
ProcessDataScenario(
|
||||
name="no_process_data",
|
||||
original_data=None,
|
||||
truncated_data=None,
|
||||
expected_truncated_flag=False,
|
||||
expected_response_data=None,
|
||||
),
|
||||
ProcessDataScenario(
|
||||
name="process_data_without_truncation",
|
||||
original_data={"small": "data"},
|
||||
truncated_data=None,
|
||||
expected_truncated_flag=False,
|
||||
expected_response_data={"small": "data"},
|
||||
),
|
||||
ProcessDataScenario(
|
||||
name="process_data_with_truncation",
|
||||
original_data={"large": "x" * 10000, "metadata": "info"},
|
||||
truncated_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
expected_truncated_flag=True,
|
||||
expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
),
|
||||
ProcessDataScenario(
|
||||
name="empty_process_data",
|
||||
original_data={},
|
||||
truncated_data=None,
|
||||
expected_truncated_flag=False,
|
||||
expected_response_data={},
|
||||
),
|
||||
ProcessDataScenario(
|
||||
name="complex_nested_data_with_truncation",
|
||||
original_data={
|
||||
"config": {"setting": "value"},
|
||||
"logs": ["log1", "log2"] * 1000, # Large list
|
||||
"status": "running",
|
||||
},
|
||||
truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"},
|
||||
expected_truncated_flag=True,
|
||||
expected_response_data={
|
||||
"config": {"setting": "value"},
|
||||
"logs": "[TRUNCATED: 2000 items]",
|
||||
"status": "running",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
get_process_data_scenarios(None),
|
||||
ids=[scenario.name for scenario in get_process_data_scenarios(None)],
|
||||
)
|
||||
def test_process_data_scenarios(self, scenario: ProcessDataScenario):
|
||||
"""Test various process_data scenarios."""
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=scenario.original_data,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
if scenario.truncated_data is not None:
|
||||
execution.set_truncated_process_data(scenario.truncated_data)
|
||||
|
||||
assert execution.process_data_truncated == scenario.expected_truncated_flag
|
||||
assert execution.get_response_process_data() == scenario.expected_response_data
|
||||
281
api/tests/unit_tests/core/workflow/graph/test_graph.py
Normal file
281
api/tests/unit_tests/core/workflow/graph/test_graph.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""Unit tests for Graph class methods."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph.edge import Edge
|
||||
from core.workflow.graph.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
|
||||
"""Create a mock node for testing."""
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.execution_type = execution_type
|
||||
node.state = state
|
||||
node.node_type = NodeType.START
|
||||
return node
|
||||
|
||||
|
||||
class TestMarkInactiveRootBranches:
|
||||
"""Test cases for _mark_inactive_root_branches method."""
|
||||
|
||||
def test_single_root_no_marking(self):
|
||||
"""Test that single root graph doesn't mark anything as skipped."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"]}
|
||||
out_edges = {"root1": ["edge1"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_multiple_roots_mark_inactive(self):
|
||||
"""Test marking inactive root branches with multiple root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "root2": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
|
||||
def test_shared_downstream_node(self):
|
||||
"""Test that shared downstream nodes are not skipped if at least one path is active."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"shared": ["edge3", "edge4"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"child1": ["edge3"],
|
||||
"child2": ["edge4"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
|
||||
def test_deep_branch_marking(self):
|
||||
"""Test marking deep branches with multiple levels."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
|
||||
"level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
|
||||
"level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
|
||||
"level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
|
||||
"level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"level1_a": ["edge1"],
|
||||
"level1_b": ["edge2"],
|
||||
"level2_a": ["edge3"],
|
||||
"level2_b": ["edge4"],
|
||||
"level3": ["edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"level1_a": ["edge3"],
|
||||
"level1_b": ["edge4"],
|
||||
"level2_b": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["level1_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level1_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level2_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level2_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
|
||||
def test_non_root_execution_type(self):
|
||||
"""Test that nodes with non-ROOT execution type are not treated as root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_empty_graph(self):
|
||||
"""Test handling of empty graph structures."""
|
||||
nodes = {}
|
||||
edges = {}
|
||||
in_edges = {}
|
||||
out_edges = {}
|
||||
|
||||
# Should not raise any errors
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")
|
||||
|
||||
def test_three_roots_mark_two_inactive(self):
|
||||
"""Test with three root nodes where two should be marked inactive."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"child3": ["edge3"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")
|
||||
|
||||
assert nodes["root1"].state == NodeState.SKIPPED
|
||||
assert nodes["root2"].state == NodeState.UNKNOWN # Active root
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.SKIPPED
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert nodes["child3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.SKIPPED
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
|
||||
def test_convergent_paths(self):
|
||||
"""Test convergent paths where multiple inactive branches lead to same node."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
|
||||
"mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
|
||||
"convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"mid1": ["edge1"],
|
||||
"mid2": ["edge2"],
|
||||
"convergent": ["edge3", "edge4", "edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
"mid1": ["edge4"],
|
||||
"mid2": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["mid1"].state == NodeState.UNKNOWN
|
||||
assert nodes["mid2"].state == NodeState.SKIPPED
|
||||
assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
assert edges["edge4"].state == NodeState.UNKNOWN
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
487
api/tests/unit_tests/core/workflow/graph_engine/README.md
Normal file
487
api/tests/unit_tests/core/workflow/graph_engine/README.md
Normal file
@ -0,0 +1,487 @@
|
||||
# Graph Engine Testing Framework
|
||||
|
||||
## Overview
|
||||
|
||||
This directory contains a comprehensive testing framework for the Graph Engine, including:
|
||||
|
||||
1. **TableTestRunner** - Advanced table-driven test framework for workflow testing
|
||||
1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies
|
||||
|
||||
## TableTestRunner Framework
|
||||
|
||||
The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows.
|
||||
|
||||
### Features
|
||||
|
||||
- **Table-driven testing** - Define test cases as structured data
|
||||
- **Parallel test execution** - Run tests concurrently for faster execution
|
||||
- **Property-based testing** - Integration with Hypothesis for fuzzing
|
||||
- **Event sequence validation** - Verify correct event ordering
|
||||
- **Mock configuration** - Seamless integration with the auto-mock system
|
||||
- **Performance metrics** - Track execution times and bottlenecks
|
||||
- **Detailed error reporting** - Comprehensive failure diagnostics
|
||||
- **Test tagging** - Organize and filter tests by tags
|
||||
- **Retry mechanism** - Handle flaky tests gracefully
|
||||
- **Custom validators** - Define custom validation logic
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
# Create test runner
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Define test case
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="simple_workflow",
|
||||
inputs={"query": "Hello"},
|
||||
expected_outputs={"result": "World"},
|
||||
description="Basic workflow test",
|
||||
)
|
||||
|
||||
# Run single test
|
||||
result = runner.run_test_case(test_case)
|
||||
assert result.success
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
#### Parallel Execution
|
||||
|
||||
```python
|
||||
runner = TableTestRunner(max_workers=8)
|
||||
|
||||
test_cases = [
|
||||
WorkflowTestCase(...),
|
||||
WorkflowTestCase(...),
|
||||
# ... more test cases
|
||||
]
|
||||
|
||||
# Run tests in parallel
|
||||
suite_result = runner.run_table_tests(
|
||||
test_cases,
|
||||
parallel=True,
|
||||
fail_fast=False
|
||||
)
|
||||
|
||||
print(f"Success rate: {suite_result.success_rate:.1f}%")
|
||||
```
|
||||
|
||||
#### Test Tagging and Filtering
|
||||
|
||||
```python
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="workflow",
|
||||
inputs={},
|
||||
expected_outputs={},
|
||||
tags=["smoke", "critical"],
|
||||
)
|
||||
|
||||
# Run only tests with specific tags
|
||||
suite_result = runner.run_table_tests(
|
||||
test_cases,
|
||||
tags_filter=["smoke"]
|
||||
)
|
||||
```
|
||||
|
||||
#### Retry Mechanism
|
||||
|
||||
```python
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="flaky_workflow",
|
||||
inputs={},
|
||||
expected_outputs={},
|
||||
retry_count=2, # Retry up to 2 times on failure
|
||||
)
|
||||
```
|
||||
|
||||
#### Custom Validators
|
||||
|
||||
```python
|
||||
def custom_validator(outputs: dict) -> bool:
|
||||
# Custom validation logic
|
||||
return "error" not in outputs.get("status", "")
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="workflow",
|
||||
inputs={},
|
||||
expected_outputs={"status": "success"},
|
||||
custom_validator=custom_validator,
|
||||
)
|
||||
```
|
||||
|
||||
#### Event Sequence Validation
|
||||
|
||||
```python
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="workflow",
|
||||
inputs={},
|
||||
expected_outputs={},
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Test Suite Reports
|
||||
|
||||
```python
|
||||
# Run test suite
|
||||
suite_result = runner.run_table_tests(test_cases)
|
||||
|
||||
# Generate detailed report
|
||||
report = runner.generate_report(suite_result)
|
||||
print(report)
|
||||
|
||||
# Access specific results
|
||||
failed_results = suite_result.get_failed_results()
|
||||
for result in failed_results:
|
||||
print(f"Failed: {result.test_case.description}")
|
||||
print(f" Error: {result.error}")
|
||||
```
|
||||
|
||||
### Performance Testing
|
||||
|
||||
```python
|
||||
# Enable logging for performance insights
|
||||
runner = TableTestRunner(
|
||||
enable_logging=True,
|
||||
log_level="DEBUG"
|
||||
)
|
||||
|
||||
# Run tests and analyze performance
|
||||
suite_result = runner.run_table_tests(test_cases)
|
||||
|
||||
# Get slowest tests
|
||||
sorted_results = sorted(
|
||||
suite_result.results,
|
||||
key=lambda r: r.execution_time,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
print("Slowest tests:")
|
||||
for result in sorted_results[:5]:
|
||||
print(f" {result.test_case.description}: {result.execution_time:.2f}s")
|
||||
```
|
||||
|
||||
## Integration: TableTestRunner + Auto-Mock System
|
||||
|
||||
The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing:
|
||||
|
||||
```python
|
||||
from test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
from test_mock_config import MockConfigBuilder
|
||||
|
||||
# Configure mocks
|
||||
mock_config = (MockConfigBuilder()
|
||||
.with_llm_response("Mocked LLM response")
|
||||
.with_tool_response({"result": "mocked"})
|
||||
.with_delays(True) # Simulate realistic delays
|
||||
.build())
|
||||
|
||||
# Create test case with mocking
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="complex_workflow",
|
||||
inputs={"query": "test"},
|
||||
expected_outputs={"answer": "Mocked LLM response"},
|
||||
use_auto_mock=True, # Enable auto-mocking
|
||||
mock_config=mock_config,
|
||||
description="Test with mocked services",
|
||||
)
|
||||
|
||||
# Run test
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(test_case)
|
||||
```
|
||||
|
||||
## Auto-Mock System
|
||||
|
||||
The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables:
|
||||
|
||||
- **Fast test execution** - No network latency or API rate limits
|
||||
- **Deterministic results** - Consistent outputs for reliable testing
|
||||
- **Cost savings** - No API usage charges during testing
|
||||
- **Offline testing** - Tests can run without internet connectivity
|
||||
- **Error simulation** - Test error handling without triggering real failures
|
||||
|
||||
## Architecture
|
||||
|
||||
The auto-mock system consists of three main components:
|
||||
|
||||
### 1. MockNodeFactory (`test_mock_factory.py`)
|
||||
|
||||
- Extends `DifyNodeFactory` to intercept node creation
|
||||
- Automatically detects nodes requiring third-party services
|
||||
- Returns mock node implementations instead of real ones
|
||||
- Supports registration of custom mock implementations
|
||||
|
||||
### 2. Mock Node Implementations (`test_mock_nodes.py`)
|
||||
|
||||
- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.)
|
||||
- `MockAgentNode` - Mocks agent execution
|
||||
- `MockToolNode` - Mocks tool invocations
|
||||
- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries
|
||||
- `MockHttpRequestNode` - Mocks HTTP requests
|
||||
- `MockParameterExtractorNode` - Mocks parameter extraction
|
||||
- `MockDocumentExtractorNode` - Mocks document processing
|
||||
- `MockQuestionClassifierNode` - Mocks question classification
|
||||
|
||||
### 3. Mock Configuration (`test_mock_config.py`)
|
||||
|
||||
- `MockConfig` - Global configuration for mock behavior
|
||||
- `NodeMockConfig` - Node-specific mock configuration
|
||||
- `MockConfigBuilder` - Fluent interface for building configurations
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```python
|
||||
from test_graph_engine import TableTestRunner, WorkflowTestCase
|
||||
from test_mock_config import MockConfigBuilder
|
||||
|
||||
# Create test runner
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mock responses
|
||||
mock_config = (MockConfigBuilder()
|
||||
.with_llm_response("Mocked LLM response")
|
||||
.build())
|
||||
|
||||
# Define test case
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="llm-simple",
|
||||
inputs={"query": "Hello"},
|
||||
expected_outputs={"answer": "Mocked LLM response"},
|
||||
use_auto_mock=True, # Enable auto-mocking
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
# Run test
|
||||
result = runner.run_test_case(test_case)
|
||||
assert result.success
|
||||
```
|
||||
|
||||
### Custom Node Outputs
|
||||
|
||||
```python
|
||||
# Configure specific outputs for individual nodes
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs("llm_node_123", {
|
||||
"text": "Custom response for this specific node",
|
||||
"usage": {"total_tokens": 50},
|
||||
"finish_reason": "stop",
|
||||
})
|
||||
```
|
||||
|
||||
### Error Simulation
|
||||
|
||||
```python
|
||||
# Simulate node failures for error handling tests
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_error("http_node", "Connection timeout")
|
||||
```
|
||||
|
||||
### Simulated Delays
|
||||
|
||||
```python
|
||||
# Add realistic execution delays
|
||||
from test_mock_config import NodeMockConfig
|
||||
|
||||
node_config = NodeMockConfig(
|
||||
node_id="llm_node",
|
||||
outputs={"text": "Response"},
|
||||
delay=1.5, # 1.5 second delay
|
||||
)
|
||||
mock_config.set_node_config("llm_node", node_config)
|
||||
```
|
||||
|
||||
### Custom Handlers
|
||||
|
||||
```python
|
||||
# Define custom logic for mock outputs
|
||||
def custom_handler(node):
|
||||
# Access node state and return dynamic outputs
|
||||
return {
|
||||
"text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}",
|
||||
}
|
||||
|
||||
node_config = NodeMockConfig(
|
||||
node_id="llm_node",
|
||||
custom_handler=custom_handler,
|
||||
)
|
||||
```
|
||||
|
||||
## Node Types Automatically Mocked
|
||||
|
||||
The following node types are automatically mocked when `use_auto_mock=True`:
|
||||
|
||||
- `LLM` - Language model nodes
|
||||
- `AGENT` - Agent execution nodes
|
||||
- `TOOL` - Tool invocation nodes
|
||||
- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes
|
||||
- `HTTP_REQUEST` - HTTP request nodes
|
||||
- `PARAMETER_EXTRACTOR` - Parameter extraction nodes
|
||||
- `DOCUMENT_EXTRACTOR` - Document processing nodes
|
||||
- `QUESTION_CLASSIFIER` - Question classification nodes
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Registering Custom Mock Implementations
|
||||
|
||||
```python
|
||||
from test_mock_factory import MockNodeFactory
|
||||
|
||||
# Create custom mock implementation
|
||||
class CustomMockNode(BaseNode):
|
||||
def _run(self):
|
||||
# Custom mock logic
|
||||
pass
|
||||
|
||||
# Register for a specific node type
|
||||
factory = MockNodeFactory(...)
|
||||
factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode)
|
||||
```
|
||||
|
||||
### Default Configurations by Node Type
|
||||
|
||||
```python
|
||||
# Set defaults for all nodes of a specific type
|
||||
mock_config.set_default_config(NodeType.LLM, {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100,
|
||||
})
|
||||
```
|
||||
|
||||
### MockConfigBuilder Fluent API
|
||||
|
||||
```python
|
||||
config = (MockConfigBuilder()
|
||||
.with_llm_response("LLM response")
|
||||
.with_agent_response("Agent response")
|
||||
.with_tool_response({"result": "data"})
|
||||
.with_retrieval_response("Retrieved content")
|
||||
.with_http_response({"status_code": 200, "body": "{}"})
|
||||
.with_node_output("node_id", {"output": "value"})
|
||||
.with_node_error("error_node", "Error message")
|
||||
.with_delays(True)
|
||||
.build())
|
||||
```
|
||||
|
||||
## Testing Workflows
|
||||
|
||||
### 1. Create Workflow Fixture
|
||||
|
||||
Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph.
|
||||
|
||||
### 2. Configure Mocks
|
||||
|
||||
Set up mock configurations for nodes that need third-party services.
|
||||
|
||||
### 3. Define Test Cases
|
||||
|
||||
Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config.
|
||||
|
||||
### 4. Run Tests
|
||||
|
||||
Use `TableTestRunner` to execute test cases and validate results.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked
|
||||
1. **Test both success and failure paths** - Use error simulation to test error handling
|
||||
1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity
|
||||
1. **Use custom handlers sparingly** - Only when dynamic behavior is needed
|
||||
1. **Document mock behavior** - Comment why specific mock values are chosen
|
||||
1. **Validate mock accuracy** - Ensure mocks reflect real service behavior
|
||||
|
||||
## Examples
|
||||
|
||||
See `test_mock_example.py` for comprehensive examples including:
|
||||
|
||||
- Basic LLM workflow testing
|
||||
- Custom node outputs
|
||||
- HTTP and tool workflow testing
|
||||
- Error simulation
|
||||
- Performance testing with delays
|
||||
|
||||
## Running Tests
|
||||
|
||||
### TableTestRunner Tests
|
||||
|
||||
```bash
|
||||
# Run graph engine tests (includes property-based tests)
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
|
||||
|
||||
# Run with specific test patterns
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -k "test_echo"
|
||||
|
||||
# Run with verbose output
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -v
|
||||
```
|
||||
|
||||
### Mock System Tests
|
||||
|
||||
```bash
|
||||
# Run auto-mock system tests
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
|
||||
|
||||
# Run examples
|
||||
uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py
|
||||
|
||||
# Run simple validation
|
||||
uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
|
||||
```
|
||||
|
||||
### All Tests
|
||||
|
||||
```bash
|
||||
# Run all graph engine tests
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ --cov=core.workflow.graph_engine
|
||||
|
||||
# Run in parallel
|
||||
uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ -n auto
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Mock not being applied
|
||||
|
||||
- Ensure `use_auto_mock=True` in `WorkflowTestCase`
|
||||
- Verify node ID matches in mock config
|
||||
- Check that node type is in the auto-mock list
|
||||
|
||||
### Issue: Unexpected outputs
|
||||
|
||||
- Debug by printing `result.actual_outputs`
|
||||
- Check if custom handler is overriding expected outputs
|
||||
- Verify mock config is properly built
|
||||
|
||||
### Issue: Import errors
|
||||
|
||||
- Ensure all mock modules are in the correct path
|
||||
- Check that required dependencies are installed
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements to the auto-mock system:
|
||||
|
||||
1. **Recording and playback** - Record real API responses for replay in tests
|
||||
1. **Mock templates** - Pre-defined mock configurations for common scenarios
|
||||
1. **Async support** - Better support for async node execution
|
||||
1. **Mock validation** - Validate mock outputs against node schemas
|
||||
1. **Performance profiling** - Built-in performance metrics for mocked workflows
|
||||
@ -0,0 +1,208 @@
|
||||
"""Tests for Redis command channel implementation."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
|
||||
class TestRedisChannel:
|
||||
"""Test suite for RedisChannel functionality."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test RedisChannel initialization."""
|
||||
mock_redis = MagicMock()
|
||||
channel_key = "test:channel:key"
|
||||
ttl = 7200
|
||||
|
||||
channel = RedisChannel(mock_redis, channel_key, ttl)
|
||||
|
||||
assert channel._redis == mock_redis
|
||||
assert channel._key == channel_key
|
||||
assert channel._command_ttl == ttl
|
||||
|
||||
def test_init_default_ttl(self):
|
||||
"""Test RedisChannel initialization with default TTL."""
|
||||
mock_redis = MagicMock()
|
||||
channel_key = "test:channel:key"
|
||||
|
||||
channel = RedisChannel(mock_redis, channel_key)
|
||||
|
||||
assert channel._command_ttl == 3600 # Default TTL
|
||||
|
||||
def test_send_command(self):
|
||||
"""Test sending a command to Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key", 3600)
|
||||
|
||||
# Create a test command
|
||||
command = GraphEngineCommand(command_type=CommandType.ABORT)
|
||||
|
||||
# Send the command
|
||||
channel.send_command(command)
|
||||
|
||||
# Verify pipeline was used
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
|
||||
# Verify rpush was called with correct data
|
||||
expected_json = json.dumps(command.model_dump())
|
||||
mock_pipe.rpush.assert_called_once_with("test:key", expected_json)
|
||||
|
||||
# Verify expire was set
|
||||
mock_pipe.expire.assert_called_once_with("test:key", 3600)
|
||||
|
||||
# Verify execute was called
|
||||
mock_pipe.execute.assert_called_once()
|
||||
|
||||
def test_fetch_commands_empty(self):
|
||||
"""Test fetching commands when Redis list is empty."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Simulate empty list
|
||||
mock_pipe.execute.return_value = [[], 1] # Empty list, delete successful
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert commands == []
|
||||
mock_pipe.lrange.assert_called_once_with("test:key", 0, -1)
|
||||
mock_pipe.delete.assert_called_once_with("test:key")
|
||||
|
||||
def test_fetch_commands_with_abort_command(self):
|
||||
"""Test fetching abort commands from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Create abort command data
|
||||
abort_command = AbortCommand()
|
||||
command_json = json.dumps(abort_command.model_dump())
|
||||
|
||||
# Simulate Redis returning one command
|
||||
mock_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], AbortCommand)
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
|
||||
def test_fetch_commands_multiple(self):
|
||||
"""Test fetching multiple commands from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Create multiple commands
|
||||
command1 = GraphEngineCommand(command_type=CommandType.ABORT)
|
||||
command2 = AbortCommand()
|
||||
|
||||
command1_json = json.dumps(command1.model_dump())
|
||||
command2_json = json.dumps(command2.model_dump())
|
||||
|
||||
# Simulate Redis returning multiple commands
|
||||
mock_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert len(commands) == 2
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert isinstance(commands[1], AbortCommand)
|
||||
|
||||
def test_fetch_commands_skips_invalid_json(self):
|
||||
"""Test that invalid JSON commands are skipped."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Mix valid and invalid JSON
|
||||
valid_command = AbortCommand()
|
||||
valid_json = json.dumps(valid_command.model_dump())
|
||||
invalid_json = b"invalid json {"
|
||||
|
||||
# Simulate Redis returning mixed valid/invalid commands
|
||||
mock_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
# Should only return the valid command
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], AbortCommand)
|
||||
|
||||
def test_deserialize_command_abort(self):
|
||||
"""Test deserializing an abort command."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
abort_data = {"command_type": CommandType.ABORT.value}
|
||||
command = channel._deserialize_command(abort_data)
|
||||
|
||||
assert isinstance(command, AbortCommand)
|
||||
assert command.command_type == CommandType.ABORT
|
||||
|
||||
def test_deserialize_command_generic(self):
|
||||
"""Test deserializing a generic command."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
# For now, only ABORT is supported, but test generic handling
|
||||
generic_data = {"command_type": CommandType.ABORT.value}
|
||||
command = channel._deserialize_command(generic_data)
|
||||
|
||||
assert command is not None
|
||||
assert command.command_type == CommandType.ABORT
|
||||
|
||||
def test_deserialize_command_invalid(self):
|
||||
"""Test deserializing invalid command data."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
# Missing command_type
|
||||
invalid_data = {"some_field": "value"}
|
||||
command = channel._deserialize_command(invalid_data)
|
||||
|
||||
assert command is None
|
||||
|
||||
def test_deserialize_command_invalid_type(self):
|
||||
"""Test deserializing command with invalid type."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
# Invalid command type
|
||||
invalid_data = {"command_type": "INVALID_TYPE"}
|
||||
command = channel._deserialize_command(invalid_data)
|
||||
|
||||
assert command is None
|
||||
|
||||
def test_atomic_fetch_and_clear(self):
|
||||
"""Test that fetch_commands atomically fetches and clears the list."""
|
||||
mock_redis = MagicMock()
|
||||
mock_pipe = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
command = AbortCommand()
|
||||
command_json = json.dumps(command.model_dump())
|
||||
mock_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
|
||||
# First fetch should return the command
|
||||
commands = channel.fetch_commands()
|
||||
assert len(commands) == 1
|
||||
|
||||
# Verify both lrange and delete were called in the pipeline
|
||||
assert mock_pipe.lrange.call_count == 1
|
||||
assert mock_pipe.delete.call_count == 1
|
||||
mock_pipe.lrange.assert_called_with("test:key", 0, -1)
|
||||
mock_pipe.delete.assert_called_with("test:key")
|
||||
@ -1,146 +0,0 @@
|
||||
import time
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def create_test_graph_runtime_state() -> GraphRuntimeState:
|
||||
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
|
||||
# Create a variable pool with system variables
|
||||
system_vars = SystemVariable(
|
||||
user_id="test_user_123",
|
||||
app_id="test_app_456",
|
||||
workflow_id="test_workflow_789",
|
||||
workflow_execution_id="test_execution_001",
|
||||
query="test query",
|
||||
conversation_id="test_conv_123",
|
||||
dialogue_count=5,
|
||||
)
|
||||
variable_pool = VariablePool(system_variables=system_vars)
|
||||
|
||||
# Add some variables to the variable pool
|
||||
variable_pool.add(["test_node", "test_var"], "test_value")
|
||||
variable_pool.add(["another_node", "another_var"], 42)
|
||||
|
||||
# Create LLM usage with realistic values
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=150,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.15"),
|
||||
completion_tokens=75,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.15"),
|
||||
total_tokens=225,
|
||||
total_price=Decimal("0.30"),
|
||||
currency="USD",
|
||||
latency=1.25,
|
||||
)
|
||||
|
||||
# Create runtime route state with some node states
|
||||
node_run_state = RuntimeRouteState()
|
||||
node_state = node_run_state.create_node_state("test_node_1")
|
||||
node_run_state.add_route(node_state.id, "target_node_id")
|
||||
|
||||
return GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
total_tokens=100,
|
||||
llm_usage=llm_usage,
|
||||
outputs={
|
||||
"string_output": "test result",
|
||||
"int_output": 42,
|
||||
"float_output": 3.14,
|
||||
"list_output": ["item1", "item2", "item3"],
|
||||
"dict_output": {"key1": "value1", "key2": 123},
|
||||
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
|
||||
},
|
||||
node_run_steps=5,
|
||||
node_run_state=node_run_state,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_round_trip_serialization():
|
||||
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
|
||||
# Create a state with non-empty values
|
||||
original_state = create_test_graph_runtime_state()
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
# Core test: ensure the round-trip preserves all values
|
||||
assert deserialized_state == original_state
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
dict_data = original_state.model_dump(mode="python")
|
||||
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||
assert deserialized_state == original_state
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
dict_data = original_state.model_dump(mode="json")
|
||||
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_outputs_field_round_trip():
|
||||
"""Test the problematic outputs field maintains values through round-trip serialization."""
|
||||
original_state = create_test_graph_runtime_state()
|
||||
|
||||
# Serialize and deserialize
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
# Verify the outputs field specifically maintains its values
|
||||
assert deserialized_state.outputs == original_state.outputs
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_empty_outputs_round_trip():
|
||||
"""Test round-trip serialization with empty outputs field."""
|
||||
variable_pool = VariablePool.empty()
|
||||
original_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
outputs={}, # Empty outputs
|
||||
)
|
||||
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_llm_usage_round_trip():
|
||||
# Create LLM usage with specific decimal values
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.0015"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.15"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.003"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.15"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.30"),
|
||||
currency="USD",
|
||||
latency=2.5,
|
||||
)
|
||||
|
||||
json_data = llm_usage.model_dump_json()
|
||||
deserialized = LLMUsage.model_validate_json(json_data)
|
||||
assert deserialized == llm_usage
|
||||
|
||||
dict_data = llm_usage.model_dump(mode="python")
|
||||
deserialized = LLMUsage.model_validate(dict_data)
|
||||
assert deserialized == llm_usage
|
||||
|
||||
dict_data = llm_usage.model_dump(mode="json")
|
||||
deserialized = LLMUsage.model_validate(dict_data)
|
||||
assert deserialized == llm_usage
|
||||
@ -1,401 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||
|
||||
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
|
||||
class TestRouteNodeStateSerialization:
|
||||
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
|
||||
|
||||
def _test_route_node_state(self):
|
||||
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||
|
||||
node_run_result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"input_key": "input_value"},
|
||||
outputs={"output_key": "output_value"},
|
||||
)
|
||||
|
||||
node_state = RouteNodeState(
|
||||
node_id="comprehensive_test_node",
|
||||
start_at=_TEST_DATETIME,
|
||||
finished_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.SUCCESS,
|
||||
node_run_result=node_run_result,
|
||||
index=5,
|
||||
paused_at=_TEST_DATETIME,
|
||||
paused_by="user_123",
|
||||
failed_reason="test_reason",
|
||||
)
|
||||
return node_state
|
||||
|
||||
def test_route_node_state_comprehensive_field_validation(self):
|
||||
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||
node_state = self._test_route_node_state()
|
||||
serialized = node_state.model_dump()
|
||||
|
||||
# Comprehensive validation of all RouteNodeState fields
|
||||
assert serialized["node_id"] == "comprehensive_test_node"
|
||||
assert serialized["status"] == RouteNodeState.Status.SUCCESS
|
||||
assert serialized["start_at"] == _TEST_DATETIME
|
||||
assert serialized["finished_at"] == _TEST_DATETIME
|
||||
assert serialized["paused_at"] == _TEST_DATETIME
|
||||
assert serialized["paused_by"] == "user_123"
|
||||
assert serialized["failed_reason"] == "test_reason"
|
||||
assert serialized["index"] == 5
|
||||
assert "id" in serialized
|
||||
assert isinstance(serialized["id"], str)
|
||||
uuid.UUID(serialized["id"]) # Validate UUID format
|
||||
|
||||
# Validate nested NodeRunResult structure
|
||||
assert serialized["node_run_result"] is not None
|
||||
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
|
||||
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
|
||||
|
||||
def test_route_node_state_minimal_required_fields(self):
|
||||
"""Test RouteNodeState with only required fields, focusing on defaults."""
|
||||
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
|
||||
|
||||
serialized = node_state.model_dump()
|
||||
|
||||
# Focus on required fields and default values (not re-testing all fields)
|
||||
assert serialized["node_id"] == "minimal_node"
|
||||
assert serialized["start_at"] == _TEST_DATETIME
|
||||
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
|
||||
assert serialized["index"] == 1 # Default index
|
||||
assert serialized["node_run_result"] is None # Default None
|
||||
json = node_state.model_dump_json()
|
||||
deserialized = RouteNodeState.model_validate_json(json)
|
||||
assert deserialized == node_state
|
||||
|
||||
def test_route_node_state_deserialization_from_dict(self):
|
||||
"""Test RouteNodeState deserialization from dictionary data."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
test_id = str(uuid.uuid4())
|
||||
|
||||
dict_data = {
|
||||
"id": test_id,
|
||||
"node_id": "deserialized_node",
|
||||
"start_at": test_datetime,
|
||||
"status": "success",
|
||||
"finished_at": test_datetime,
|
||||
"index": 3,
|
||||
}
|
||||
|
||||
node_state = RouteNodeState.model_validate(dict_data)
|
||||
|
||||
# Focus on deserialization accuracy
|
||||
assert node_state.id == test_id
|
||||
assert node_state.node_id == "deserialized_node"
|
||||
assert node_state.start_at == test_datetime
|
||||
assert node_state.status == RouteNodeState.Status.SUCCESS
|
||||
assert node_state.finished_at == test_datetime
|
||||
assert node_state.index == 3
|
||||
|
||||
def test_route_node_state_round_trip_consistency(self):
|
||||
node_states = (
|
||||
self._test_route_node_state(),
|
||||
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
|
||||
)
|
||||
for node_state in node_states:
|
||||
json = node_state.model_dump_json()
|
||||
deserialized = RouteNodeState.model_validate_json(json)
|
||||
assert deserialized == node_state
|
||||
|
||||
dict_ = node_state.model_dump(mode="python")
|
||||
deserialized = RouteNodeState.model_validate(dict_)
|
||||
assert deserialized == node_state
|
||||
|
||||
dict_ = node_state.model_dump(mode="json")
|
||||
deserialized = RouteNodeState.model_validate(dict_)
|
||||
assert deserialized == node_state
|
||||
|
||||
|
||||
class TestRouteNodeStateEnumSerialization:
|
||||
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
|
||||
|
||||
def test_status_enum_model_dump_behavior(self):
|
||||
"""Test Status enum serialization in model_dump() returns enum objects."""
|
||||
|
||||
for status_enum in RouteNodeState.Status:
|
||||
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
|
||||
serialized = node_state.model_dump(mode="python")
|
||||
assert serialized["status"] == status_enum
|
||||
serialized = node_state.model_dump(mode="json")
|
||||
assert serialized["status"] == status_enum.value
|
||||
|
||||
def test_status_enum_json_serialization_behavior(self):
|
||||
"""Test Status enum serialization in JSON returns string values."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
enum_to_string_mapping = {
|
||||
RouteNodeState.Status.RUNNING: "running",
|
||||
RouteNodeState.Status.SUCCESS: "success",
|
||||
RouteNodeState.Status.FAILED: "failed",
|
||||
RouteNodeState.Status.PAUSED: "paused",
|
||||
RouteNodeState.Status.EXCEPTION: "exception",
|
||||
}
|
||||
|
||||
for status_enum, expected_string in enum_to_string_mapping.items():
|
||||
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
|
||||
|
||||
json_data = json.loads(node_state.model_dump_json())
|
||||
assert json_data["status"] == expected_string
|
||||
|
||||
def test_status_enum_deserialization_from_string(self):
|
||||
"""Test Status enum deserialization from string values."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
string_to_enum_mapping = {
|
||||
"running": RouteNodeState.Status.RUNNING,
|
||||
"success": RouteNodeState.Status.SUCCESS,
|
||||
"failed": RouteNodeState.Status.FAILED,
|
||||
"paused": RouteNodeState.Status.PAUSED,
|
||||
"exception": RouteNodeState.Status.EXCEPTION,
|
||||
}
|
||||
|
||||
for status_string, expected_enum in string_to_enum_mapping.items():
|
||||
dict_data = {
|
||||
"node_id": "enum_deserialize_test",
|
||||
"start_at": test_datetime,
|
||||
"status": status_string,
|
||||
}
|
||||
|
||||
node_state = RouteNodeState.model_validate(dict_data)
|
||||
assert node_state.status == expected_enum
|
||||
|
||||
|
||||
class TestRuntimeRouteStateSerialization:
|
||||
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
|
||||
|
||||
_NODE1_ID = "node_1"
|
||||
_ROUTE_STATE1_ID = str(uuid.uuid4())
|
||||
_NODE2_ID = "node_2"
|
||||
_ROUTE_STATE2_ID = str(uuid.uuid4())
|
||||
_NODE3_ID = "node_3"
|
||||
_ROUTE_STATE3_ID = str(uuid.uuid4())
|
||||
|
||||
def _get_runtime_route_state(self):
|
||||
# Create node states with different configurations
|
||||
node_state_1 = RouteNodeState(
|
||||
id=self._ROUTE_STATE1_ID,
|
||||
node_id=self._NODE1_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
index=1,
|
||||
)
|
||||
node_state_2 = RouteNodeState(
|
||||
id=self._ROUTE_STATE2_ID,
|
||||
node_id=self._NODE2_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.SUCCESS,
|
||||
finished_at=_TEST_DATETIME,
|
||||
index=2,
|
||||
)
|
||||
node_state_3 = RouteNodeState(
|
||||
id=self._ROUTE_STATE3_ID,
|
||||
node_id=self._NODE3_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.FAILED,
|
||||
failed_reason="Test failure",
|
||||
index=3,
|
||||
)
|
||||
|
||||
runtime_state = RuntimeRouteState(
|
||||
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
|
||||
node_state_mapping={
|
||||
node_state_1.id: node_state_1,
|
||||
node_state_2.id: node_state_2,
|
||||
node_state_3.id: node_state_3,
|
||||
},
|
||||
)
|
||||
|
||||
return runtime_state
|
||||
|
||||
def test_runtime_route_state_comprehensive_structure_validation(self):
|
||||
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
|
||||
|
||||
runtime_state = self._get_runtime_route_state()
|
||||
serialized = runtime_state.model_dump()
|
||||
|
||||
# Comprehensive validation of RuntimeRouteState structure
|
||||
assert "routes" in serialized
|
||||
assert "node_state_mapping" in serialized
|
||||
assert isinstance(serialized["routes"], dict)
|
||||
assert isinstance(serialized["node_state_mapping"], dict)
|
||||
|
||||
# Validate routes dictionary structure and content
|
||||
assert len(serialized["routes"]) == 2
|
||||
assert self._ROUTE_STATE1_ID in serialized["routes"]
|
||||
assert self._ROUTE_STATE2_ID in serialized["routes"]
|
||||
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
|
||||
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
|
||||
|
||||
# Validate node_state_mapping dictionary structure and content
|
||||
assert len(serialized["node_state_mapping"]) == 3
|
||||
for state_id in [
|
||||
self._ROUTE_STATE1_ID,
|
||||
self._ROUTE_STATE2_ID,
|
||||
self._ROUTE_STATE3_ID,
|
||||
]:
|
||||
assert state_id in serialized["node_state_mapping"]
|
||||
node_data = serialized["node_state_mapping"][state_id]
|
||||
node_state = runtime_state.node_state_mapping[state_id]
|
||||
assert node_data["node_id"] == node_state.node_id
|
||||
assert node_data["status"] == node_state.status
|
||||
assert node_data["index"] == node_state.index
|
||||
|
||||
def test_runtime_route_state_empty_collections(self):
|
||||
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
|
||||
runtime_state = RuntimeRouteState()
|
||||
serialized = runtime_state.model_dump()
|
||||
|
||||
# Focus on default empty collection behavior
|
||||
assert serialized["routes"] == {}
|
||||
assert serialized["node_state_mapping"] == {}
|
||||
assert isinstance(serialized["routes"], dict)
|
||||
assert isinstance(serialized["node_state_mapping"], dict)
|
||||
|
||||
def test_runtime_route_state_json_serialization_structure(self):
|
||||
"""Test RuntimeRouteState JSON serialization structure."""
|
||||
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
|
||||
|
||||
runtime_state = RuntimeRouteState(
|
||||
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
|
||||
)
|
||||
|
||||
json_str = runtime_state.model_dump_json()
|
||||
json_data = json.loads(json_str)
|
||||
|
||||
# Focus on JSON structure validation
|
||||
assert isinstance(json_str, str)
|
||||
assert isinstance(json_data, dict)
|
||||
assert "routes" in json_data
|
||||
assert "node_state_mapping" in json_data
|
||||
assert json_data["routes"]["source"] == ["target1", "target2"]
|
||||
assert node_state.id in json_data["node_state_mapping"]
|
||||
|
||||
def test_runtime_route_state_deserialization_from_dict(self):
|
||||
"""Test RuntimeRouteState deserialization from dictionary data."""
|
||||
node_id = str(uuid.uuid4())
|
||||
|
||||
dict_data = {
|
||||
"routes": {"source_node": ["target_node_1", "target_node_2"]},
|
||||
"node_state_mapping": {
|
||||
node_id: {
|
||||
"id": node_id,
|
||||
"node_id": "test_node",
|
||||
"start_at": _TEST_DATETIME,
|
||||
"status": "running",
|
||||
"index": 1,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
runtime_state = RuntimeRouteState.model_validate(dict_data)
|
||||
|
||||
# Focus on deserialization accuracy
|
||||
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
|
||||
assert len(runtime_state.node_state_mapping) == 1
|
||||
assert node_id in runtime_state.node_state_mapping
|
||||
|
||||
deserialized_node = runtime_state.node_state_mapping[node_id]
|
||||
assert deserialized_node.node_id == "test_node"
|
||||
assert deserialized_node.status == RouteNodeState.Status.RUNNING
|
||||
assert deserialized_node.index == 1
|
||||
|
||||
def test_runtime_route_state_round_trip_consistency(self):
|
||||
"""Test RuntimeRouteState round-trip serialization consistency."""
|
||||
original = self._get_runtime_route_state()
|
||||
|
||||
# Dictionary round trip
|
||||
dict_data = original.model_dump(mode="python")
|
||||
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||
assert reconstructed == original
|
||||
|
||||
dict_data = original.model_dump(mode="json")
|
||||
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||
assert reconstructed == original
|
||||
|
||||
# JSON round trip
|
||||
json_str = original.model_dump_json()
|
||||
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
|
||||
assert json_reconstructed == original
|
||||
|
||||
|
||||
class TestSerializationEdgeCases:
|
||||
"""Test edge cases and error conditions for serialization/deserialization."""
|
||||
|
||||
def test_invalid_status_deserialization(self):
|
||||
"""Test deserialization with invalid status values."""
|
||||
test_datetime = _TEST_DATETIME
|
||||
invalid_data = {
|
||||
"node_id": "invalid_test",
|
||||
"start_at": test_datetime,
|
||||
"status": "invalid_status",
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(invalid_data)
|
||||
assert "status" in str(exc_info.value)
|
||||
|
||||
def test_missing_required_fields_deserialization(self):
|
||||
"""Test deserialization with missing required fields."""
|
||||
incomplete_data = {"id": str(uuid.uuid4())}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(incomplete_data)
|
||||
error_str = str(exc_info.value)
|
||||
assert "node_id" in error_str or "start_at" in error_str
|
||||
|
||||
def test_invalid_datetime_deserialization(self):
|
||||
"""Test deserialization with invalid datetime values."""
|
||||
invalid_data = {
|
||||
"node_id": "datetime_test",
|
||||
"start_at": "invalid_datetime",
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(invalid_data)
|
||||
assert "start_at" in str(exc_info.value)
|
||||
|
||||
def test_invalid_routes_structure_deserialization(self):
|
||||
"""Test RuntimeRouteState deserialization with invalid routes structure."""
|
||||
invalid_data = {
|
||||
"routes": "invalid_routes_structure", # Should be dict
|
||||
"node_state_mapping": {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RuntimeRouteState.model_validate(invalid_data)
|
||||
assert "routes" in str(exc_info.value)
|
||||
|
||||
def test_timezone_handling_in_datetime_fields(self):
|
||||
"""Test timezone handling in datetime field serialization."""
|
||||
utc_datetime = datetime.now(UTC)
|
||||
naive_datetime = utc_datetime.replace(tzinfo=None)
|
||||
|
||||
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
|
||||
dict_ = node_state.model_dump()
|
||||
|
||||
assert dict_["start_at"] == naive_datetime
|
||||
|
||||
# Test round trip
|
||||
reconstructed = RouteNodeState.model_validate(dict_)
|
||||
assert reconstructed.start_at == naive_datetime
|
||||
assert reconstructed.start_at.tzinfo is None
|
||||
|
||||
json = node_state.model_dump_json()
|
||||
|
||||
reconstructed = RouteNodeState.model_validate_json(json)
|
||||
assert reconstructed.start_at == naive_datetime
|
||||
assert reconstructed.start_at.tzinfo is None
|
||||
@ -0,0 +1,120 @@
|
||||
"""Tests for graph engine event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
|
||||
from core.workflow.graph_engine.event_management.event_manager import EventManager
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
|
||||
|
||||
class _StubEdgeProcessor:
|
||||
"""Minimal edge processor stub for tests."""
|
||||
|
||||
|
||||
class _StubErrorHandler:
|
||||
"""Minimal error handler stub for tests."""
|
||||
|
||||
|
||||
class _StubNode:
|
||||
"""Simple node stub exposing the attributes needed by the state manager."""
|
||||
|
||||
def __init__(self, node_id: str) -> None:
|
||||
self.id = node_id
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = "Stub Node"
|
||||
self.execution_type = NodeExecutionType.EXECUTABLE
|
||||
self.error_strategy = None
|
||||
self.retry_config = RetryConfig()
|
||||
self.retry = False
|
||||
|
||||
|
||||
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
|
||||
"""Construct an EventHandler with in-memory dependencies for testing."""
|
||||
|
||||
node = _StubNode(node_id)
|
||||
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_execution = GraphExecution(workflow_id="test-workflow")
|
||||
|
||||
event_manager = EventManager()
|
||||
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
|
||||
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
|
||||
|
||||
handler = EventHandler(
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
graph_execution=graph_execution,
|
||||
response_coordinator=response_coordinator,
|
||||
event_collector=event_manager,
|
||||
edge_processor=_StubEdgeProcessor(),
|
||||
state_manager=state_manager,
|
||||
error_handler=_StubErrorHandler(),
|
||||
)
|
||||
|
||||
return handler, event_manager, graph_execution
|
||||
|
||||
|
||||
def test_retry_does_not_emit_additional_start_event() -> None:
|
||||
"""Ensure retry attempts do not produce duplicate start events."""
|
||||
|
||||
node_id = "test-node"
|
||||
handler, event_manager, graph_execution = _build_event_handler(node_id)
|
||||
|
||||
execution_id = "exec-1"
|
||||
node_type = NodeType.CODE
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(start_event)
|
||||
|
||||
retry_event = NodeRunRetryEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
error="boom",
|
||||
retry_index=1,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="boom",
|
||||
error_type="TestError",
|
||||
),
|
||||
)
|
||||
handler.dispatch(retry_event)
|
||||
|
||||
# Simulate the node starting execution again after retry
|
||||
second_start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(second_start_event)
|
||||
|
||||
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
|
||||
|
||||
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
|
||||
|
||||
node_execution = graph_execution.get_or_create_node_execution(node_id)
|
||||
assert node_execution.retry_count == 1
|
||||
@ -0,0 +1,37 @@
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_answer_end_with_text():
|
||||
fixture_name = "answer_end_with_text"
|
||||
case = WorkflowTestCase(
|
||||
fixture_name,
|
||||
query="Hello, AI!",
|
||||
expected_outputs={"answer": "prefixHello, AI!suffix"},
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
# Start
|
||||
NodeRunStartedEvent,
|
||||
# The chunks are now emitted as the Answer node processes them
|
||||
# since sys.query is a special selector that gets attributed to
|
||||
# the active response node
|
||||
NodeRunStreamChunkEvent, # prefix
|
||||
NodeRunStreamChunkEvent, # sys.query
|
||||
NodeRunStreamChunkEvent, # suffix
|
||||
NodeRunSucceededEvent,
|
||||
# Answer
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
)
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
@ -0,0 +1,24 @@
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_array_iteration_formatting_workflow():
|
||||
"""
|
||||
Validate Iteration node processes [1,2,3] into formatted strings.
|
||||
|
||||
Fixture description expects:
|
||||
{"output": ["output: 1", "output: 2", "output: 3"]}
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="array_iteration_formatting_workflow",
|
||||
inputs={},
|
||||
expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]},
|
||||
description="Iteration formats numbers into strings",
|
||||
use_auto_mock=True,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Iteration workflow failed: {result.error}"
|
||||
assert result.actual_outputs == test_case.expected_outputs
|
||||
@ -0,0 +1,356 @@
|
||||
"""
|
||||
Tests for the auto-mock system.
|
||||
|
||||
This module contains tests that validate the auto-mock functionality
|
||||
for workflows containing nodes that require third-party services.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_simple_llm_workflow_with_auto_mock():
|
||||
"""Test that a simple LLM workflow runs successfully with auto-mocking."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Create mock configuration
|
||||
mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="basic_llm_chat_workflow",
|
||||
inputs={"query": "Hello, how are you?"},
|
||||
expected_outputs={"answer": "This is a test response from mocked LLM"},
|
||||
description="Simple LLM workflow with auto-mock",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
assert result.actual_outputs is not None
|
||||
assert "answer" in result.actual_outputs
|
||||
assert result.actual_outputs["answer"] == "This is a test response from mocked LLM"
|
||||
|
||||
|
||||
def test_llm_workflow_with_custom_node_output():
|
||||
"""Test LLM workflow with custom output for specific node."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Create mock configuration with custom output for specific node
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs(
|
||||
"llm_node",
|
||||
{
|
||||
"text": "Custom response for this specific node",
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="basic_llm_chat_workflow",
|
||||
inputs={"query": "Test query"},
|
||||
expected_outputs={"answer": "Custom response for this specific node"},
|
||||
description="LLM workflow with custom node output",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
assert result.actual_outputs is not None
|
||||
assert result.actual_outputs["answer"] == "Custom response for this specific node"
|
||||
|
||||
|
||||
def test_http_tool_workflow_with_auto_mock():
|
||||
"""Test workflow with HTTP request and tool nodes using auto-mock."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Create mock configuration
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs(
|
||||
"http_node",
|
||||
{
|
||||
"status_code": 200,
|
||||
"body": '{"key": "value", "number": 42}',
|
||||
"headers": {"content-type": "application/json"},
|
||||
},
|
||||
)
|
||||
mock_config.set_node_outputs(
|
||||
"tool_node",
|
||||
{
|
||||
"result": {"key": "value", "number": 42},
|
||||
},
|
||||
)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="http_request_with_json_tool_workflow",
|
||||
inputs={"url": "https://api.example.com/data"},
|
||||
expected_outputs={
|
||||
"status_code": 200,
|
||||
"parsed_data": {"key": "value", "number": 42},
|
||||
},
|
||||
description="HTTP and Tool workflow with auto-mock",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
assert result.actual_outputs is not None
|
||||
assert result.actual_outputs["status_code"] == 200
|
||||
assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42}
|
||||
|
||||
|
||||
def test_workflow_with_simulated_node_error():
|
||||
"""Test that workflows handle simulated node errors correctly."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Create mock configuration with error
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_error("llm_node", "Simulated LLM API error")
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="basic_llm_chat_workflow",
|
||||
inputs={"query": "This should fail"},
|
||||
expected_outputs={}, # We expect failure, so no outputs
|
||||
description="LLM workflow with simulated error",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
# The workflow should fail due to the simulated error
|
||||
assert not result.success
|
||||
assert result.error is not None
|
||||
|
||||
|
||||
def test_workflow_with_mock_delays():
|
||||
"""Test that mock delays work correctly."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Create mock configuration with delays
|
||||
mock_config = MockConfig(simulate_delays=True)
|
||||
node_config = NodeMockConfig(
|
||||
node_id="llm_node",
|
||||
outputs={"text": "Response after delay"},
|
||||
delay=0.1, # 100ms delay
|
||||
)
|
||||
mock_config.set_node_config("llm_node", node_config)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="basic_llm_chat_workflow",
|
||||
inputs={"query": "Test with delay"},
|
||||
expected_outputs={"answer": "Response after delay"},
|
||||
description="LLM workflow with simulated delay",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
# Execution time should be at least the delay
|
||||
assert result.execution_time >= 0.1
|
||||
|
||||
|
||||
def test_mock_config_builder():
|
||||
"""Test the MockConfigBuilder fluent interface."""
|
||||
config = (
|
||||
MockConfigBuilder()
|
||||
.with_llm_response("LLM response")
|
||||
.with_agent_response("Agent response")
|
||||
.with_tool_response({"tool": "output"})
|
||||
.with_retrieval_response("Retrieval content")
|
||||
.with_http_response({"status_code": 201, "body": "created"})
|
||||
.with_node_output("node1", {"output": "value"})
|
||||
.with_node_error("node2", "error message")
|
||||
.with_delays(True)
|
||||
.build()
|
||||
)
|
||||
|
||||
assert config.default_llm_response == "LLM response"
|
||||
assert config.default_agent_response == "Agent response"
|
||||
assert config.default_tool_response == {"tool": "output"}
|
||||
assert config.default_retrieval_response == "Retrieval content"
|
||||
assert config.default_http_response == {"status_code": 201, "body": "created"}
|
||||
assert config.simulate_delays is True
|
||||
|
||||
node1_config = config.get_node_config("node1")
|
||||
assert node1_config is not None
|
||||
assert node1_config.outputs == {"output": "value"}
|
||||
|
||||
node2_config = config.get_node_config("node2")
|
||||
assert node2_config is not None
|
||||
assert node2_config.error == "error message"
|
||||
|
||||
|
||||
def test_mock_factory_node_type_detection():
|
||||
"""Test that MockNodeFactory correctly identifies nodes to mock."""
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None, # Will be set by test
|
||||
graph_runtime_state=None, # Will be set by test
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# Test that third-party service nodes are identified for mocking
|
||||
assert factory.should_mock_node(NodeType.LLM)
|
||||
assert factory.should_mock_node(NodeType.AGENT)
|
||||
assert factory.should_mock_node(NodeType.TOOL)
|
||||
assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL)
|
||||
assert factory.should_mock_node(NodeType.HTTP_REQUEST)
|
||||
assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR)
|
||||
assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR)
|
||||
|
||||
# Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy)
|
||||
assert factory.should_mock_node(NodeType.CODE)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Test that non-service nodes are not mocked
|
||||
assert not factory.should_mock_node(NodeType.START)
|
||||
assert not factory.should_mock_node(NodeType.END)
|
||||
assert not factory.should_mock_node(NodeType.IF_ELSE)
|
||||
assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR)
|
||||
|
||||
|
||||
def test_custom_mock_handler():
|
||||
"""Test using a custom handler function for mock outputs."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Custom handler that modifies output based on input
|
||||
def custom_llm_handler(node) -> dict:
|
||||
# In a real scenario, we could access node.graph_runtime_state.variable_pool
|
||||
# to get the actual inputs
|
||||
return {
|
||||
"text": "Custom handler response",
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 3,
|
||||
"total_tokens": 8,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
|
||||
mock_config = MockConfig()
|
||||
node_config = NodeMockConfig(
|
||||
node_id="llm_node",
|
||||
custom_handler=custom_llm_handler,
|
||||
)
|
||||
mock_config.set_node_config("llm_node", node_config)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="basic_llm_chat_workflow",
|
||||
inputs={"query": "Test custom handler"},
|
||||
expected_outputs={"answer": "Custom handler response"},
|
||||
description="LLM workflow with custom handler",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
assert result.actual_outputs["answer"] == "Custom handler response"
|
||||
|
||||
|
||||
def test_workflow_without_auto_mock():
|
||||
"""Test that workflows work normally without auto-mock enabled."""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# This test uses the echo workflow which doesn't need external services
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="simple_passthrough_workflow",
|
||||
inputs={"query": "Test without mock"},
|
||||
expected_outputs={"query": "Test without mock"},
|
||||
description="Echo workflow without auto-mock",
|
||||
use_auto_mock=False, # Auto-mock disabled
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Workflow failed: {result.error}"
|
||||
assert result.actual_outputs["query"] == "Test without mock"
|
||||
|
||||
|
||||
def test_register_custom_mock_node():
|
||||
"""Test registering a custom mock implementation for a node type."""
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
# Create a custom mock for TemplateTransformNode
|
||||
class MockTemplateTransformNode(TemplateTransformNode):
|
||||
def _run(self):
|
||||
# Custom mock implementation
|
||||
pass
|
||||
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Unregister mock
|
||||
factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM)
|
||||
assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Re-register custom mock
|
||||
factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, MockTemplateTransformNode)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
|
||||
def test_default_config_by_node_type():
|
||||
"""Test setting default configurations by node type."""
|
||||
mock_config = MockConfig()
|
||||
|
||||
# Set default config for all LLM nodes
|
||||
mock_config.set_default_config(
|
||||
NodeType.LLM,
|
||||
{
|
||||
"default_response": "Default LLM response for all nodes",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
# Set default config for all HTTP nodes
|
||||
mock_config.set_default_config(
|
||||
NodeType.HTTP_REQUEST,
|
||||
{
|
||||
"default_status": 200,
|
||||
"default_timeout": 30,
|
||||
},
|
||||
)
|
||||
|
||||
llm_config = mock_config.get_default_config(NodeType.LLM)
|
||||
assert llm_config["default_response"] == "Default LLM response for all nodes"
|
||||
assert llm_config["temperature"] == 0.7
|
||||
|
||||
http_config = mock_config.get_default_config(NodeType.HTTP_REQUEST)
|
||||
assert http_config["default_status"] == 200
|
||||
assert http_config["default_timeout"] == 30
|
||||
|
||||
# Non-configured node type should return empty dict
|
||||
tool_config = mock_config.get_default_config(NodeType.TOOL)
|
||||
assert tool_config == {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests
|
||||
pytest.main([__file__, "-v"])
|
||||
@ -0,0 +1,41 @@
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_basic_chatflow():
|
||||
fixture_name = "basic_chatflow"
|
||||
mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build()
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
expected_outputs={"answer": "mocked llm response"},
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
# START
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# LLM
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
+ [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2)
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
# ANSWER
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
@ -0,0 +1,107 @@
|
||||
"""Test the command system for GraphEngine control."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
|
||||
|
||||
|
||||
def test_abort_command():
|
||||
"""Test that GraphEngine properly handles abort commands."""
|
||||
|
||||
# Create shared GraphRuntimeState
|
||||
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a minimal mock graph
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
# Create mock nodes with required attributes - using shared runtime state
|
||||
mock_start_node = MagicMock()
|
||||
mock_start_node.state = None
|
||||
mock_start_node.id = "start"
|
||||
mock_start_node.graph_runtime_state = shared_runtime_state # Use shared instance
|
||||
mock_graph.nodes["start"] = mock_start_node
|
||||
|
||||
# Mock graph methods
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
# Create command channel
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
# Create GraphEngine with same shared runtime state
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=shared_runtime_state, # Use shared instance
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
# Send abort command before starting
|
||||
abort_command = AbortCommand(reason="Test abort")
|
||||
command_channel.send_command(abort_command)
|
||||
|
||||
# Run engine and collect events
|
||||
events = list(engine.run())
|
||||
|
||||
# Verify we get start and abort events
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunAbortedEvent) for e in events)
|
||||
|
||||
# Find the abort event and check its reason
|
||||
abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)]
|
||||
assert len(abort_events) == 1
|
||||
assert abort_events[0].reason is not None
|
||||
assert "aborted: test abort" in abort_events[0].reason.lower()
|
||||
|
||||
|
||||
def test_redis_channel_serialization():
|
||||
"""Test that Redis channel properly serializes and deserializes commands."""
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Mock redis client
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
|
||||
# Create channel with a specific key
|
||||
channel = RedisChannel(mock_redis, channel_key="workflow:123:commands")
|
||||
|
||||
# Test sending a command
|
||||
abort_command = AbortCommand(reason="Test abort")
|
||||
channel.send_command(abort_command)
|
||||
|
||||
# Verify redis methods were called
|
||||
mock_pipeline.rpush.assert_called_once()
|
||||
mock_pipeline.expire.assert_called_once()
|
||||
|
||||
# Verify the serialized data
|
||||
call_args = mock_pipeline.rpush.call_args
|
||||
key = call_args[0][0]
|
||||
command_json = call_args[0][1]
|
||||
|
||||
assert key == "workflow:123:commands"
|
||||
|
||||
# Verify JSON structure
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == "abort"
|
||||
assert command_data["reason"] == "Test abort"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_abort_command()
|
||||
test_redis_channel_serialization()
|
||||
print("All tests passed!")
|
||||
@ -0,0 +1,134 @@
|
||||
"""
|
||||
Test suite for complex branch workflow with parallel execution and conditional routing.
|
||||
|
||||
This test suite validates the behavior of a workflow that:
|
||||
1. Executes nodes in parallel (IF/ELSE and LLM branches)
|
||||
2. Routes based on conditional logic (query containing 'hello')
|
||||
3. Handles multiple answer nodes with different outputs
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
class TestComplexBranchWorkflow:
|
||||
"""Test suite for complex branch workflow with parallel execution."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment before each test method."""
|
||||
self.runner = TableTestRunner()
|
||||
self.fixture_path = "test_complex_branch"
|
||||
|
||||
@pytest.mark.skip(reason="output in this workflow can be random")
|
||||
def test_hello_branch_with_llm(self):
|
||||
"""
|
||||
Test when query contains 'hello' - should trigger true branch.
|
||||
Both IF/ELSE and LLM should execute in parallel.
|
||||
"""
|
||||
mock_text_1 = "This is a mocked LLM response for hello world"
|
||||
test_cases = [
|
||||
WorkflowTestCase(
|
||||
fixture_path=self.fixture_path,
|
||||
query="hello world",
|
||||
expected_outputs={
|
||||
"answer": f"{mock_text_1}contains 'hello'",
|
||||
},
|
||||
description="Basic hello case with parallel LLM execution",
|
||||
use_auto_mock=True,
|
||||
mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()),
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
# Start
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# If/Else (no streaming)
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# LLM (with streaming)
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
# LLM
|
||||
+ [NodeRunStreamChunkEvent] * (mock_text_1.count(" ") + 2)
|
||||
+ [
|
||||
# Answer's text
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Answer
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Answer 2
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
),
|
||||
WorkflowTestCase(
|
||||
fixture_path=self.fixture_path,
|
||||
query="say hello to everyone",
|
||||
expected_outputs={
|
||||
"answer": "Mocked response for greetingcontains 'hello'",
|
||||
},
|
||||
description="Hello in middle of sentence",
|
||||
use_auto_mock=True,
|
||||
mock_config=(
|
||||
MockConfigBuilder()
|
||||
.with_node_output("1755502777322", {"text": "Mocked response for greeting"})
|
||||
.build()
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
suite_result = self.runner.run_table_tests(test_cases)
|
||||
|
||||
for result in suite_result.results:
|
||||
assert result.success, f"Test '{result.test_case.description}' failed: {result.error}"
|
||||
assert result.actual_outputs
|
||||
|
||||
def test_non_hello_branch_with_llm(self):
|
||||
"""
|
||||
Test when query doesn't contain 'hello' - should trigger false branch.
|
||||
LLM output should be used as the final answer.
|
||||
"""
|
||||
test_cases = [
|
||||
WorkflowTestCase(
|
||||
fixture_path=self.fixture_path,
|
||||
query="goodbye world",
|
||||
expected_outputs={
|
||||
"answer": "Mocked LLM response for goodbye",
|
||||
},
|
||||
description="Goodbye case - false branch with LLM output",
|
||||
use_auto_mock=True,
|
||||
mock_config=(
|
||||
MockConfigBuilder()
|
||||
.with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"})
|
||||
.build()
|
||||
),
|
||||
),
|
||||
WorkflowTestCase(
|
||||
fixture_path=self.fixture_path,
|
||||
query="test message",
|
||||
expected_outputs={
|
||||
"answer": "Mocked response for test",
|
||||
},
|
||||
description="Regular message - false branch",
|
||||
use_auto_mock=True,
|
||||
mock_config=(
|
||||
MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build()
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
suite_result = self.runner.run_table_tests(test_cases)
|
||||
|
||||
for result in suite_result.results:
|
||||
assert result.success, f"Test '{result.test_case.description}' failed: {result.error}"
|
||||
@ -0,0 +1,210 @@
|
||||
"""
|
||||
Test for streaming output workflow behavior.
|
||||
|
||||
This test validates that:
|
||||
- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node)
|
||||
- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output)
|
||||
"""
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_table_runner import TableTestRunner
|
||||
|
||||
|
||||
def test_streaming_output_with_blocking_equals_one():
|
||||
"""
|
||||
Test workflow when blocking == 1 (LLM → Template → End).
|
||||
|
||||
Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present.
|
||||
This test should FAIL according to requirements.
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Load the workflow configuration
|
||||
fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow")
|
||||
|
||||
# Create graph from fixture with auto-mock enabled
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
inputs={"query": "Hello, how are you?", "blocking": 1},
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
events = list(engine.run())
|
||||
|
||||
# Check for successful completion
|
||||
success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)]
|
||||
assert len(success_events) > 0, "Workflow should complete successfully"
|
||||
|
||||
# Check for streaming events
|
||||
stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)]
|
||||
stream_chunk_count = len(stream_chunk_events)
|
||||
|
||||
# According to requirements, we expect exactly 3 streaming events from the End node
|
||||
# 1. User query
|
||||
# 2. Newline
|
||||
# 3. Template output (which contains the LLM response)
|
||||
assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}"
|
||||
|
||||
first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2]
|
||||
assert first_chunk.chunk == "Hello, how are you?", (
|
||||
f"Expected first chunk to be user input, but got {first_chunk.chunk}"
|
||||
)
|
||||
assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}"
|
||||
# Third chunk will be the template output with the mock LLM response
|
||||
assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}"
|
||||
|
||||
# Find indices of first LLM success event and first stream chunk event
|
||||
llm2_start_index = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM),
|
||||
-1,
|
||||
)
|
||||
first_chunk_index = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)),
|
||||
-1,
|
||||
)
|
||||
|
||||
assert first_chunk_index < llm2_start_index, (
|
||||
f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}"
|
||||
)
|
||||
|
||||
# Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent
|
||||
start_node_id = graph.root_node.id
|
||||
start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id]
|
||||
assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}"
|
||||
start_event = start_events[0]
|
||||
query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"]
|
||||
assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id"
|
||||
|
||||
# Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent
|
||||
start_events = [
|
||||
e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.TEMPLATE_TRANSFORM
|
||||
]
|
||||
template_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.TEMPLATE_TRANSFORM]
|
||||
assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}"
|
||||
assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), (
|
||||
"Expected all Template chunk events to have same id with Template's NodeRunStartedEvent"
|
||||
)
|
||||
|
||||
# Check that NodeRunStreamChunkEvent contains '\n' is from the End node
|
||||
end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END]
|
||||
assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}"
|
||||
newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"]
|
||||
assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}"
|
||||
# The newline chunk should be from the End node (check node_id, not execution id)
|
||||
assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), (
|
||||
"Expected all newline chunk events to be from End node"
|
||||
)
|
||||
|
||||
|
||||
def test_streaming_output_with_blocking_not_equals_one():
|
||||
"""
|
||||
Test workflow when blocking != 1 (LLM → End directly).
|
||||
|
||||
End node should produce streaming output with NodeRunStreamChunkEvent.
|
||||
This test should PASS according to requirements.
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Load the workflow configuration
|
||||
fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow")
|
||||
|
||||
# Create graph from fixture with auto-mock enabled
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
inputs={"query": "Hello, how are you?", "blocking": 2},
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
events = list(engine.run())
|
||||
|
||||
# Check for successful completion
|
||||
success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)]
|
||||
assert len(success_events) > 0, "Workflow should complete successfully"
|
||||
|
||||
# Check for streaming events - expecting streaming events
|
||||
stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)]
|
||||
stream_chunk_count = len(stream_chunk_events)
|
||||
|
||||
# This assertion should PASS according to requirements
|
||||
assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}"
|
||||
|
||||
# We should have at least 2 chunks (query and newline)
|
||||
assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}"
|
||||
|
||||
first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1]
|
||||
assert first_chunk.chunk == "Hello, how are you?", (
|
||||
f"Expected first chunk to be user input, but got {first_chunk.chunk}"
|
||||
)
|
||||
assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}"
|
||||
|
||||
# Find indices of first LLM success event and first stream chunk event
|
||||
llm2_start_index = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM),
|
||||
-1,
|
||||
)
|
||||
first_chunk_index = next(
|
||||
(i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)),
|
||||
-1,
|
||||
)
|
||||
|
||||
assert first_chunk_index < llm2_start_index, (
|
||||
f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}"
|
||||
)
|
||||
|
||||
# With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks
|
||||
# and they are strings
|
||||
for chunk_event in stream_chunk_events[2:]:
|
||||
assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}"
|
||||
|
||||
# Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent
|
||||
start_node_id = graph.root_node.id
|
||||
start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id]
|
||||
assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}"
|
||||
start_event = start_events[0]
|
||||
query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"]
|
||||
assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id"
|
||||
|
||||
# Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes
|
||||
start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.LLM]
|
||||
llm_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.LLM]
|
||||
llm_node_ids = {se.node_id for se in start_events}
|
||||
assert all(e.node_id in llm_node_ids for e in llm_chunk_events), (
|
||||
"Expected all LLM chunk events to be from LLM nodes"
|
||||
)
|
||||
|
||||
# Check that NodeRunStreamChunkEvent contains '\n' is from the End node
|
||||
end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END]
|
||||
assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}"
|
||||
newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"]
|
||||
assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}"
|
||||
# The newline chunk should be from the End node (check node_id, not execution id)
|
||||
assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), (
|
||||
"Expected all newline chunk events to be from End node"
|
||||
)
|
||||
@ -1,791 +0,0 @@
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
def test_init():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "llm-source-answer-target",
|
||||
"source": "llm",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "start-source-qc-target",
|
||||
"source": "start",
|
||||
"target": "qc",
|
||||
},
|
||||
{
|
||||
"id": "qc-1-llm-target",
|
||||
"source": "qc",
|
||||
"sourceHandle": "1",
|
||||
"target": "llm",
|
||||
},
|
||||
{
|
||||
"id": "qc-2-http-target",
|
||||
"source": "qc",
|
||||
"sourceHandle": "2",
|
||||
"target": "http",
|
||||
},
|
||||
{
|
||||
"id": "http-source-answer2-target",
|
||||
"source": "http",
|
||||
"target": "answer2",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "question-classifier"},
|
||||
"id": "qc",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
},
|
||||
"id": "http",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
start_node_id = "start"
|
||||
|
||||
assert graph.root_node_id == start_node_id
|
||||
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
|
||||
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
|
||||
|
||||
|
||||
def test__init_iteration_graph():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "llm-answer",
|
||||
"source": "llm",
|
||||
"sourceHandle": "source",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "iteration-source-llm-target",
|
||||
"source": "iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "llm",
|
||||
},
|
||||
{
|
||||
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
|
||||
"source": "template-transform-in-iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "llm-in-iteration",
|
||||
},
|
||||
{
|
||||
"id": "llm-in-iteration-source-answer-in-iteration-target",
|
||||
"source": "llm-in-iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "answer-in-iteration",
|
||||
},
|
||||
{
|
||||
"id": "start-source-code-target",
|
||||
"source": "start",
|
||||
"sourceHandle": "source",
|
||||
"target": "code",
|
||||
},
|
||||
{
|
||||
"id": "code-source-iteration-target",
|
||||
"source": "code",
|
||||
"sourceHandle": "source",
|
||||
"target": "iteration",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start",
|
||||
},
|
||||
"id": "start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "iteration"},
|
||||
"id": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
},
|
||||
"id": "template-transform-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
|
||||
graph.add_extra_edge(
|
||||
source_node_id="answer-in-iteration",
|
||||
target_node_id="template-transform-in-iteration",
|
||||
run_condition=RunCondition(
|
||||
type="condition",
|
||||
conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")],
|
||||
),
|
||||
)
|
||||
|
||||
# iteration:
|
||||
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
|
||||
|
||||
assert graph.root_node_id == "template-transform-in-iteration"
|
||||
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
|
||||
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
|
||||
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
|
||||
|
||||
|
||||
def test_parallels_graph():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-answer-target",
|
||||
"source": "llm3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
start_edges = graph.edge_mapping.get("start")
|
||||
assert start_edges is not None
|
||||
assert start_edges[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
llm_edges = graph.edge_mapping.get(f"llm{i + 1}")
|
||||
assert llm_edges is not None
|
||||
assert llm_edges[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph2():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
if i < 2:
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph3():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph4():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-code3-target",
|
||||
"source": "llm3",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code3-source-answer-target",
|
||||
"source": "code3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
|
||||
assert graph.edge_mapping.get(f"code{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 6
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph5():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm4",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm5",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code1-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-code1-target",
|
||||
"source": "llm2",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-code2-target",
|
||||
"source": "llm3",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm4-source-code2-target",
|
||||
"source": "llm4",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm5-source-code3-target",
|
||||
"source": "llm5",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm5",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(5):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm2") is not None
|
||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm3") is not None
|
||||
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm4") is not None
|
||||
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm5") is not None
|
||||
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
|
||||
assert graph.edge_mapping.get("code1") is not None
|
||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code2") is not None
|
||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 8
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph6():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code1-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code2-target",
|
||||
"source": "llm1",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-code3-target",
|
||||
"source": "llm2",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code3-source-answer-target",
|
||||
"source": "code3",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-answer-target",
|
||||
"source": "llm3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm2") is not None
|
||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
|
||||
assert graph.edge_mapping.get("code1") is not None
|
||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code2") is not None
|
||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code3") is not None
|
||||
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 2
|
||||
assert len(graph.node_parallel_mapping) == 6
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
parent_parallel = None
|
||||
child_parallel = None
|
||||
for p_id, parallel in graph.parallel_mapping.items():
|
||||
if parallel.parent_parallel_id is None:
|
||||
parent_parallel = parallel
|
||||
else:
|
||||
child_parallel = parallel
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code3"]:
|
||||
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
|
||||
|
||||
for node_id in ["code1", "code2"]:
|
||||
assert graph.node_parallel_mapping[node_id] == child_parallel.id
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,194 @@
|
||||
"""Unit tests for GraphExecution serialization helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph_engine.domain import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_engine.response_coordinator.path import Path
|
||||
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||
|
||||
|
||||
class CustomGraphExecutionError(Exception):
|
||||
"""Custom exception used to verify error serialization."""
|
||||
|
||||
|
||||
def test_graph_execution_serialization_round_trip() -> None:
|
||||
"""GraphExecution serialization restores full aggregate state."""
|
||||
# Arrange
|
||||
execution = GraphExecution(workflow_id="wf-1")
|
||||
execution.start()
|
||||
node_a = execution.get_or_create_node_execution("node-a")
|
||||
node_a.mark_started(execution_id="exec-1")
|
||||
node_a.increment_retry()
|
||||
node_a.mark_failed("boom")
|
||||
node_b = execution.get_or_create_node_execution("node-b")
|
||||
node_b.mark_skipped()
|
||||
execution.fail(CustomGraphExecutionError("serialization failure"))
|
||||
|
||||
# Act
|
||||
serialized = execution.dumps()
|
||||
payload = json.loads(serialized)
|
||||
restored = GraphExecution(workflow_id="wf-1")
|
||||
restored.loads(serialized)
|
||||
|
||||
# Assert
|
||||
assert payload["type"] == "GraphExecution"
|
||||
assert payload["version"] == "1.0"
|
||||
assert restored.workflow_id == "wf-1"
|
||||
assert restored.started is True
|
||||
assert restored.completed is True
|
||||
assert restored.aborted is False
|
||||
assert isinstance(restored.error, CustomGraphExecutionError)
|
||||
assert str(restored.error) == "serialization failure"
|
||||
assert set(restored.node_executions) == {"node-a", "node-b"}
|
||||
restored_node_a = restored.node_executions["node-a"]
|
||||
assert restored_node_a.state is NodeState.TAKEN
|
||||
assert restored_node_a.retry_count == 1
|
||||
assert restored_node_a.execution_id == "exec-1"
|
||||
assert restored_node_a.error == "boom"
|
||||
restored_node_b = restored.node_executions["node-b"]
|
||||
assert restored_node_b.state is NodeState.SKIPPED
|
||||
assert restored_node_b.retry_count == 0
|
||||
assert restored_node_b.execution_id is None
|
||||
assert restored_node_b.error is None
|
||||
|
||||
|
||||
def test_graph_execution_loads_replaces_existing_state() -> None:
|
||||
"""loads replaces existing runtime data with serialized snapshot."""
|
||||
# Arrange
|
||||
source = GraphExecution(workflow_id="wf-2")
|
||||
source.start()
|
||||
source_node = source.get_or_create_node_execution("node-source")
|
||||
source_node.mark_taken()
|
||||
serialized = source.dumps()
|
||||
|
||||
target = GraphExecution(workflow_id="wf-2")
|
||||
target.start()
|
||||
target.abort("pre-existing abort")
|
||||
temp_node = target.get_or_create_node_execution("node-temp")
|
||||
temp_node.increment_retry()
|
||||
temp_node.mark_failed("temp error")
|
||||
|
||||
# Act
|
||||
target.loads(serialized)
|
||||
|
||||
# Assert
|
||||
assert target.aborted is False
|
||||
assert target.error is None
|
||||
assert target.started is True
|
||||
assert target.completed is False
|
||||
assert set(target.node_executions) == {"node-source"}
|
||||
restored_node = target.node_executions["node-source"]
|
||||
assert restored_node.state is NodeState.TAKEN
|
||||
assert restored_node.retry_count == 0
|
||||
assert restored_node.execution_id is None
|
||||
assert restored_node.error is None
|
||||
|
||||
|
||||
def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
|
||||
"""ResponseStreamCoordinator serialization restores coordinator internals."""
|
||||
|
||||
template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
|
||||
template_secondary = Template(segments=[TextSegment(text="secondary")])
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
|
||||
self.id = node_id
|
||||
self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
|
||||
self.execution_type = execution_type
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = node_id
|
||||
self.template = template
|
||||
|
||||
def blocks_variable_output(self, *_args) -> bool:
|
||||
return False
|
||||
|
||||
response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
|
||||
source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
|
||||
|
||||
class DummyGraph:
|
||||
def __init__(self) -> None:
|
||||
self.nodes = {
|
||||
response_node1.id: response_node1,
|
||||
response_node2.id: response_node2,
|
||||
response_node3.id: response_node3,
|
||||
source_node.id: source_node,
|
||||
}
|
||||
self.edges: dict[str, object] = {}
|
||||
self.root_node = response_node1
|
||||
|
||||
def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
graph = DummyGraph()
|
||||
|
||||
def fake_from_node(cls, node: DummyNode) -> ResponseSession:
|
||||
return ResponseSession(node_id=node.id, template=node.template)
|
||||
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
coordinator._response_nodes = {"response-1", "response-2", "response-3"}
|
||||
coordinator._paths_maps = {
|
||||
"response-1": [Path(edges=["edge-1"])],
|
||||
"response-2": [Path(edges=[])],
|
||||
"response-3": [Path(edges=["edge-2", "edge-3"])],
|
||||
}
|
||||
|
||||
active_session = ResponseSession(node_id="response-1", template=response_node1.template)
|
||||
active_session.index = 1
|
||||
coordinator._active_session = active_session
|
||||
waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
|
||||
coordinator._waiting_sessions = deque([waiting_session])
|
||||
pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
|
||||
pending_session.index = 2
|
||||
coordinator._response_sessions = {"response-3": pending_session}
|
||||
|
||||
coordinator._node_execution_ids = {"response-1": "exec-1"}
|
||||
event = NodeRunStreamChunkEvent(
|
||||
id="exec-1",
|
||||
node_id="response-1",
|
||||
node_type=NodeType.ANSWER,
|
||||
selector=["node-source", "text"],
|
||||
chunk="chunk-1",
|
||||
is_final=False,
|
||||
)
|
||||
coordinator._stream_buffers = {("node-source", "text"): [event]}
|
||||
coordinator._stream_positions = {("node-source", "text"): 1}
|
||||
coordinator._closed_streams = {("node-source", "text")}
|
||||
|
||||
serialized = coordinator.dumps()
|
||||
|
||||
restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
restored.loads(serialized)
|
||||
|
||||
assert restored._response_nodes == {"response-1", "response-2", "response-3"}
|
||||
assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
|
||||
assert restored._active_session is not None
|
||||
assert restored._active_session.node_id == "response-1"
|
||||
assert restored._active_session.index == 1
|
||||
waiting_restored = list(restored._waiting_sessions)
|
||||
assert len(waiting_restored) == 1
|
||||
assert waiting_restored[0].node_id == "response-2"
|
||||
assert waiting_restored[0].index == 0
|
||||
assert set(restored._response_sessions) == {"response-3"}
|
||||
assert restored._response_sessions["response-3"].index == 2
|
||||
assert restored._node_execution_ids == {"response-1": "exec-1"}
|
||||
assert ("node-source", "text") in restored._stream_buffers
|
||||
restored_event = restored._stream_buffers[("node-source", "text")][0]
|
||||
assert restored_event.chunk == "chunk-1"
|
||||
assert restored._stream_positions[("node-source", "text")] == 1
|
||||
assert ("node-source", "text") in restored._closed_streams
|
||||
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Test case for loop with inner answer output error scenario.
|
||||
|
||||
This test validates the behavior of a loop containing an answer node
|
||||
inside the loop that may produce output errors.
|
||||
"""
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_loop_contains_answer():
|
||||
"""
|
||||
Test loop with inner answer node that may have output errors.
|
||||
|
||||
The fixture implements a loop that:
|
||||
1. Iterates 4 times (index 0-3)
|
||||
2. Contains an inner answer node that outputs index and item values
|
||||
3. Has a break condition when index equals 4
|
||||
4. Tests error handling for answer nodes within loops
|
||||
"""
|
||||
fixture_name = "loop_contains_answer"
|
||||
mock_config = MockConfigBuilder().build()
|
||||
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
query="1",
|
||||
expected_outputs={"answer": "1\n2\n1 + 2"},
|
||||
expected_event_sequence=[
|
||||
# Graph start
|
||||
GraphRunStartedEvent,
|
||||
# Start
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Loop start
|
||||
NodeRunStartedEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
# Variable assigner
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent, # 1
|
||||
NodeRunStreamChunkEvent, # \n
|
||||
NodeRunSucceededEvent,
|
||||
# Answer
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Loop next
|
||||
NodeRunLoopNextEvent,
|
||||
# Variable assigner
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent, # 2
|
||||
NodeRunStreamChunkEvent, # \n
|
||||
NodeRunSucceededEvent,
|
||||
# Answer
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Loop end
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunStreamChunkEvent, # 1
|
||||
NodeRunStreamChunkEvent, # +
|
||||
NodeRunStreamChunkEvent, # 2
|
||||
NodeRunSucceededEvent,
|
||||
# Answer
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Graph end
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
@ -0,0 +1,41 @@
|
||||
"""
|
||||
Test cases for the Loop node functionality using TableTestRunner.
|
||||
|
||||
This module tests the loop node's ability to:
|
||||
1. Execute iterations with loop variables
|
||||
2. Handle break conditions correctly
|
||||
3. Update and propagate loop variables between iterations
|
||||
4. Output the final loop variable value
|
||||
"""
|
||||
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_table_runner import (
|
||||
TableTestRunner,
|
||||
WorkflowTestCase,
|
||||
)
|
||||
|
||||
|
||||
def test_loop_with_break_condition():
|
||||
"""
|
||||
Test loop node with break condition.
|
||||
|
||||
The increment_loop_with_break_condition_workflow.yml fixture implements a loop that:
|
||||
1. Starts with num=1
|
||||
2. Increments num by 1 each iteration
|
||||
3. Breaks when num >= 5
|
||||
4. Should output {"num": 5}
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="increment_loop_with_break_condition_workflow",
|
||||
inputs={}, # No inputs needed for this test
|
||||
expected_outputs={"num": 5},
|
||||
description="Loop with break condition when num >= 5",
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
# Assert the test passed
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
assert result.actual_outputs is not None, "Should have outputs"
|
||||
assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}"
|
||||
@ -0,0 +1,67 @@
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_loop_with_tool():
|
||||
fixture_name = "search_dify_from_2023_to_2025"
|
||||
mock_config = (
|
||||
MockConfigBuilder()
|
||||
.with_tool_response(
|
||||
{
|
||||
"text": "mocked search result",
|
||||
}
|
||||
)
|
||||
.build()
|
||||
)
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
expected_outputs={
|
||||
"answer": """- mocked search result
|
||||
- mocked search result"""
|
||||
},
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
# START
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# LOOP START
|
||||
NodeRunStartedEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
# 2023
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
# 2024
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# LOOP END
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunStreamChunkEvent, # loop.res
|
||||
NodeRunSucceededEvent,
|
||||
# ANSWER
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
@ -0,0 +1,165 @@
|
||||
"""
|
||||
Configuration system for mock nodes in testing.
|
||||
|
||||
This module provides a flexible configuration system for customizing
|
||||
the behavior of mock nodes during testing.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeMockConfig:
|
||||
"""Configuration for a specific node mock."""
|
||||
|
||||
node_id: str
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
delay: float = 0.0 # Simulated execution delay in seconds
|
||||
custom_handler: Callable[..., dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
"""
|
||||
Global configuration for mock nodes in a test.
|
||||
|
||||
This configuration allows tests to customize the behavior of mock nodes,
|
||||
including their outputs, errors, and execution characteristics.
|
||||
"""
|
||||
|
||||
# Node-specific configurations by node ID
|
||||
node_configs: dict[str, NodeMockConfig] = field(default_factory=dict)
|
||||
|
||||
# Default configurations by node type
|
||||
default_configs: dict[NodeType, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
# Global settings
|
||||
enable_auto_mock: bool = True
|
||||
simulate_delays: bool = False
|
||||
default_llm_response: str = "This is a mocked LLM response"
|
||||
default_agent_response: str = "This is a mocked agent response"
|
||||
default_tool_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked tool output"})
|
||||
default_retrieval_response: str = "This is mocked retrieval content"
|
||||
default_http_response: dict[str, Any] = field(
|
||||
default_factory=lambda: {"status_code": 200, "body": "mocked response", "headers": {}}
|
||||
)
|
||||
default_template_transform_response: str = "This is mocked template transform output"
|
||||
default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
|
||||
|
||||
def get_node_config(self, node_id: str) -> NodeMockConfig | None:
|
||||
"""Get configuration for a specific node."""
|
||||
return self.node_configs.get(node_id)
|
||||
|
||||
def set_node_config(self, node_id: str, config: NodeMockConfig) -> None:
|
||||
"""Set configuration for a specific node."""
|
||||
self.node_configs[node_id] = config
|
||||
|
||||
def set_node_outputs(self, node_id: str, outputs: dict[str, Any]) -> None:
|
||||
"""Set expected outputs for a specific node."""
|
||||
if node_id not in self.node_configs:
|
||||
self.node_configs[node_id] = NodeMockConfig(node_id=node_id)
|
||||
self.node_configs[node_id].outputs = outputs
|
||||
|
||||
def set_node_error(self, node_id: str, error: str) -> None:
|
||||
"""Set an error for a specific node to simulate failure."""
|
||||
if node_id not in self.node_configs:
|
||||
self.node_configs[node_id] = NodeMockConfig(node_id=node_id)
|
||||
self.node_configs[node_id].error = error
|
||||
|
||||
def get_default_config(self, node_type: NodeType) -> dict[str, Any]:
|
||||
"""Get default configuration for a node type."""
|
||||
return self.default_configs.get(node_type, {})
|
||||
|
||||
def set_default_config(self, node_type: NodeType, config: dict[str, Any]) -> None:
|
||||
"""Set default configuration for a node type."""
|
||||
self.default_configs[node_type] = config
|
||||
|
||||
|
||||
class MockConfigBuilder:
|
||||
"""
|
||||
Builder for creating MockConfig instances with a fluent interface.
|
||||
|
||||
Example:
|
||||
config = (MockConfigBuilder()
|
||||
.with_llm_response("Custom LLM response")
|
||||
.with_node_output("node_123", {"text": "specific output"})
|
||||
.with_node_error("node_456", "Simulated error")
|
||||
.build())
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._config = MockConfig()
|
||||
|
||||
def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
|
||||
"""Enable or disable auto-mocking."""
|
||||
self._config.enable_auto_mock = enabled
|
||||
return self
|
||||
|
||||
def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
|
||||
"""Enable or disable simulated execution delays."""
|
||||
self._config.simulate_delays = enabled
|
||||
return self
|
||||
|
||||
def with_llm_response(self, response: str) -> "MockConfigBuilder":
|
||||
"""Set default LLM response."""
|
||||
self._config.default_llm_response = response
|
||||
return self
|
||||
|
||||
def with_agent_response(self, response: str) -> "MockConfigBuilder":
|
||||
"""Set default agent response."""
|
||||
self._config.default_agent_response = response
|
||||
return self
|
||||
|
||||
def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
"""Set default tool response."""
|
||||
self._config.default_tool_response = response
|
||||
return self
|
||||
|
||||
def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
|
||||
"""Set default retrieval response."""
|
||||
self._config.default_retrieval_response = response
|
||||
return self
|
||||
|
||||
def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
"""Set default HTTP response."""
|
||||
self._config.default_http_response = response
|
||||
return self
|
||||
|
||||
def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
|
||||
"""Set default template transform response."""
|
||||
self._config.default_template_transform_response = response
|
||||
return self
|
||||
|
||||
def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
"""Set default code execution response."""
|
||||
self._config.default_code_response = response
|
||||
return self
|
||||
|
||||
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
|
||||
"""Set outputs for a specific node."""
|
||||
self._config.set_node_outputs(node_id, outputs)
|
||||
return self
|
||||
|
||||
def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
|
||||
"""Set error for a specific node."""
|
||||
self._config.set_node_error(node_id, error)
|
||||
return self
|
||||
|
||||
def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
|
||||
"""Add a node-specific configuration."""
|
||||
self._config.set_node_config(config.node_id, config)
|
||||
return self
|
||||
|
||||
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
|
||||
"""Set default configuration for a node type."""
|
||||
self._config.set_default_config(node_type, config)
|
||||
return self
|
||||
|
||||
def build(self) -> MockConfig:
|
||||
"""Build and return the MockConfig instance."""
|
||||
return self._config
|
||||
@ -0,0 +1,281 @@
|
||||
"""
|
||||
Example demonstrating the auto-mock system for testing workflows.
|
||||
|
||||
This example shows how to test workflows with third-party service nodes
|
||||
without making actual API calls.
|
||||
"""
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def example_test_llm_workflow():
|
||||
"""
|
||||
Example: Testing a workflow with an LLM node.
|
||||
|
||||
This demonstrates how to test a workflow that uses an LLM service
|
||||
without making actual API calls to OpenAI, Anthropic, etc.
|
||||
"""
|
||||
print("\n=== Example: Testing LLM Workflow ===\n")
|
||||
|
||||
# Initialize the test runner
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mock responses
|
||||
mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build()
|
||||
|
||||
# Define the test case
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="llm-simple",
|
||||
inputs={"query": "Hello, AI!"},
|
||||
expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"},
|
||||
description="Testing LLM workflow with mocked response",
|
||||
use_auto_mock=True, # Enable auto-mocking
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
# Run the test
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
if result.success:
|
||||
print("✅ Test passed!")
|
||||
print(f" Input: {test_case.inputs['query']}")
|
||||
print(f" Output: {result.actual_outputs['answer']}")
|
||||
print(f" Execution time: {result.execution_time:.2f}s")
|
||||
else:
|
||||
print(f"❌ Test failed: {result.error}")
|
||||
|
||||
return result.success
|
||||
|
||||
|
||||
def example_test_with_custom_outputs():
|
||||
"""
|
||||
Example: Testing with custom outputs for specific nodes.
|
||||
|
||||
This shows how to provide different mock outputs for specific node IDs,
|
||||
useful when testing complex workflows with multiple LLM/tool nodes.
|
||||
"""
|
||||
print("\n=== Example: Custom Node Outputs ===\n")
|
||||
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mock with specific outputs for different nodes
|
||||
mock_config = MockConfigBuilder().build()
|
||||
|
||||
# Set custom output for a specific LLM node
|
||||
mock_config.set_node_outputs(
|
||||
"llm_node",
|
||||
{
|
||||
"text": "This is a custom response for the specific LLM node",
|
||||
"usage": {
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 70,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="llm-simple",
|
||||
inputs={"query": "Tell me about custom outputs"},
|
||||
expected_outputs={"answer": "This is a custom response for the specific LLM node"},
|
||||
description="Testing with custom node outputs",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
if result.success:
|
||||
print("✅ Test with custom outputs passed!")
|
||||
print(f" Custom output: {result.actual_outputs['answer']}")
|
||||
else:
|
||||
print(f"❌ Test failed: {result.error}")
|
||||
|
||||
return result.success
|
||||
|
||||
|
||||
def example_test_http_and_tool_workflow():
|
||||
"""
|
||||
Example: Testing a workflow with HTTP request and tool nodes.
|
||||
|
||||
This demonstrates mocking external HTTP calls and tool executions.
|
||||
"""
|
||||
print("\n=== Example: HTTP and Tool Workflow ===\n")
|
||||
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mocks for HTTP and Tool nodes
|
||||
mock_config = MockConfigBuilder().build()
|
||||
|
||||
# Mock HTTP response
|
||||
mock_config.set_node_outputs(
|
||||
"http_node",
|
||||
{
|
||||
"status_code": 200,
|
||||
"body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}',
|
||||
"headers": {"content-type": "application/json"},
|
||||
},
|
||||
)
|
||||
|
||||
# Mock tool response (e.g., JSON parser)
|
||||
mock_config.set_node_outputs(
|
||||
"tool_node",
|
||||
{
|
||||
"result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
|
||||
},
|
||||
)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="http-tool-workflow",
|
||||
inputs={"url": "https://api.example.com/users"},
|
||||
expected_outputs={
|
||||
"status_code": 200,
|
||||
"parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
|
||||
},
|
||||
description="Testing HTTP and Tool workflow",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
if result.success:
|
||||
print("✅ HTTP and Tool workflow test passed!")
|
||||
print(f" HTTP Status: {result.actual_outputs['status_code']}")
|
||||
print(f" Parsed Data: {result.actual_outputs['parsed_data']}")
|
||||
else:
|
||||
print(f"❌ Test failed: {result.error}")
|
||||
|
||||
return result.success
|
||||
|
||||
|
||||
def example_test_error_simulation():
|
||||
"""
|
||||
Example: Simulating errors in specific nodes.
|
||||
|
||||
This shows how to test error handling in workflows by simulating
|
||||
failures in specific nodes.
|
||||
"""
|
||||
print("\n=== Example: Error Simulation ===\n")
|
||||
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mock to simulate an error
|
||||
mock_config = MockConfigBuilder().build()
|
||||
mock_config.set_node_error("llm_node", "API rate limit exceeded")
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="llm-simple",
|
||||
inputs={"query": "This will fail"},
|
||||
expected_outputs={}, # We expect failure
|
||||
description="Testing error handling",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
if not result.success:
|
||||
print("✅ Error simulation worked as expected!")
|
||||
print(f" Simulated error: {result.error}")
|
||||
else:
|
||||
print("❌ Expected failure but test succeeded")
|
||||
|
||||
return not result.success # Success means we got the expected error
|
||||
|
||||
|
||||
def example_test_with_delays():
|
||||
"""
|
||||
Example: Testing with simulated execution delays.
|
||||
|
||||
This demonstrates how to simulate realistic execution times
|
||||
for performance testing.
|
||||
"""
|
||||
print("\n=== Example: Simulated Delays ===\n")
|
||||
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Configure mock with delays
|
||||
mock_config = (
|
||||
MockConfigBuilder()
|
||||
.with_delays(True) # Enable delay simulation
|
||||
.with_llm_response("Response after delay")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Add specific delay for the LLM node
|
||||
from .test_mock_config import NodeMockConfig
|
||||
|
||||
node_config = NodeMockConfig(
|
||||
node_id="llm_node",
|
||||
outputs={"text": "Response after delay"},
|
||||
delay=0.5, # 500ms delay
|
||||
)
|
||||
mock_config.set_node_config("llm_node", node_config)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="llm-simple",
|
||||
inputs={"query": "Test with delay"},
|
||||
expected_outputs={"answer": "Response after delay"},
|
||||
description="Testing with simulated delays",
|
||||
use_auto_mock=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
if result.success:
|
||||
print("✅ Delay simulation test passed!")
|
||||
print(f" Execution time: {result.execution_time:.2f}s")
|
||||
print(" (Should be >= 0.5s due to simulated delay)")
|
||||
else:
|
||||
print(f"❌ Test failed: {result.error}")
|
||||
|
||||
return result.success and result.execution_time >= 0.5
|
||||
|
||||
|
||||
def run_all_examples():
|
||||
"""Run all example tests."""
|
||||
print("\n" + "=" * 50)
|
||||
print("AUTO-MOCK SYSTEM EXAMPLES")
|
||||
print("=" * 50)
|
||||
|
||||
examples = [
|
||||
example_test_llm_workflow,
|
||||
example_test_with_custom_outputs,
|
||||
example_test_http_and_tool_workflow,
|
||||
example_test_error_simulation,
|
||||
example_test_with_delays,
|
||||
]
|
||||
|
||||
results = []
|
||||
for example in examples:
|
||||
try:
|
||||
results.append(example())
|
||||
except Exception as e:
|
||||
print(f"\n❌ Example failed with exception: {e}")
|
||||
results.append(False)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("SUMMARY")
|
||||
print("=" * 50)
|
||||
|
||||
passed = sum(results)
|
||||
total = len(results)
|
||||
print(f"\n✅ Passed: {passed}/{total}")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All examples passed successfully!")
|
||||
else:
|
||||
print(f"\n⚠️ {total - passed} example(s) failed")
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
success = run_all_examples()
|
||||
sys.exit(0 if success else 1)
|
||||
@ -0,0 +1,146 @@
|
||||
"""
|
||||
Mock node factory for testing workflows with third-party service dependencies.
|
||||
|
||||
This module provides a MockNodeFactory that automatically detects and mocks nodes
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
from .test_mock_nodes import (
|
||||
MockAgentNode,
|
||||
MockCodeNode,
|
||||
MockDocumentExtractorNode,
|
||||
MockHttpRequestNode,
|
||||
MockIterationNode,
|
||||
MockKnowledgeRetrievalNode,
|
||||
MockLLMNode,
|
||||
MockLoopNode,
|
||||
MockParameterExtractorNode,
|
||||
MockQuestionClassifierNode,
|
||||
MockTemplateTransformNode,
|
||||
MockToolNode,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
|
||||
|
||||
class MockNodeFactory(DifyNodeFactory):
|
||||
"""
|
||||
A factory that creates mock nodes for testing purposes.
|
||||
|
||||
This factory intercepts node creation and returns mock implementations
|
||||
for nodes that require third-party services, allowing tests to run
|
||||
without external dependencies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
mock_config: "MockConfig | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the mock node factory.
|
||||
|
||||
:param graph_init_params: Graph initialization parameters
|
||||
:param graph_runtime_state: Graph runtime state
|
||||
:param mock_config: Optional mock configuration for customizing mock behavior
|
||||
"""
|
||||
super().__init__(graph_init_params, graph_runtime_state)
|
||||
self.mock_config = mock_config
|
||||
|
||||
# Map of node types that should be mocked
|
||||
self._mock_node_types = {
|
||||
NodeType.LLM: MockLLMNode,
|
||||
NodeType.AGENT: MockAgentNode,
|
||||
NodeType.TOOL: MockToolNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: MockKnowledgeRetrievalNode,
|
||||
NodeType.HTTP_REQUEST: MockHttpRequestNode,
|
||||
NodeType.QUESTION_CLASSIFIER: MockQuestionClassifierNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: MockParameterExtractorNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: MockDocumentExtractorNode,
|
||||
NodeType.ITERATION: MockIterationNode,
|
||||
NodeType.LOOP: MockLoopNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: MockTemplateTransformNode,
|
||||
NodeType.CODE: MockCodeNode,
|
||||
}
|
||||
|
||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
||||
"""
|
||||
Create a node instance, using mock implementations for third-party service nodes.
|
||||
|
||||
:param node_config: Node configuration dictionary
|
||||
:return: Node instance (real or mocked)
|
||||
"""
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
node_type_str = node_data.get("type")
|
||||
|
||||
if not node_type_str:
|
||||
# Fall back to parent implementation for nodes without type
|
||||
return super().create_node(node_config)
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
# Unknown node type, use parent implementation
|
||||
return super().create_node(node_config)
|
||||
|
||||
# Check if this node type should be mocked
|
||||
if node_type in self._mock_node_types:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
)
|
||||
|
||||
# Initialize node with provided data
|
||||
mock_instance.init_node_data(node_data)
|
||||
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
return super().create_node(node_config)
|
||||
|
||||
def should_mock_node(self, node_type: NodeType) -> bool:
|
||||
"""
|
||||
Check if a node type should be mocked.
|
||||
|
||||
:param node_type: The node type to check
|
||||
:return: True if the node should be mocked, False otherwise
|
||||
"""
|
||||
return node_type in self._mock_node_types
|
||||
|
||||
def register_mock_node_type(self, node_type: NodeType, mock_class: type[Node]) -> None:
|
||||
"""
|
||||
Register a custom mock implementation for a node type.
|
||||
|
||||
:param node_type: The node type to mock
|
||||
:param mock_class: The mock class to use for this node type
|
||||
"""
|
||||
self._mock_node_types[node_type] = mock_class
|
||||
|
||||
def unregister_mock_node_type(self, node_type: NodeType) -> None:
|
||||
"""
|
||||
Remove a mock implementation for a node type.
|
||||
|
||||
:param node_type: The node type to stop mocking
|
||||
"""
|
||||
if node_type in self._mock_node_types:
|
||||
del self._mock_node_types[node_type]
|
||||
@ -0,0 +1,168 @@
|
||||
"""
|
||||
Simple test to verify MockNodeFactory works with iteration nodes.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add api directory to path
|
||||
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
|
||||
sys.path.insert(0, str(api_dir))
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
|
||||
|
||||
|
||||
def test_mock_factory_registers_iteration_node():
|
||||
"""Test that MockNodeFactory has iteration node registered."""
|
||||
|
||||
# Create a MockNodeFactory instance
|
||||
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
|
||||
|
||||
# Check that iteration node is registered
|
||||
assert NodeType.ITERATION in factory._mock_node_types
|
||||
print("✓ Iteration node is registered in MockNodeFactory")
|
||||
|
||||
# Check that loop node is registered
|
||||
assert NodeType.LOOP in factory._mock_node_types
|
||||
print("✓ Loop node is registered in MockNodeFactory")
|
||||
|
||||
# Check the class types
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode
|
||||
|
||||
assert factory._mock_node_types[NodeType.ITERATION] == MockIterationNode
|
||||
print("✓ Iteration node maps to MockIterationNode class")
|
||||
|
||||
assert factory._mock_node_types[NodeType.LOOP] == MockLoopNode
|
||||
print("✓ Loop node maps to MockLoopNode class")
|
||||
|
||||
|
||||
def test_mock_iteration_node_preserves_config():
|
||||
"""Test that MockIterationNode preserves mock configuration."""
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfigBuilder().with_llm_response("Test response").build()
|
||||
|
||||
# Create minimal graph init params
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create minimal runtime state
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
|
||||
# Create mock iteration node
|
||||
node_config = {
|
||||
"id": "iter1",
|
||||
"data": {
|
||||
"type": "iteration",
|
||||
"title": "Test",
|
||||
"iterator_selector": ["start", "items"],
|
||||
"output_selector": ["node", "text"],
|
||||
"start_node_id": "node1",
|
||||
},
|
||||
}
|
||||
|
||||
mock_node = MockIterationNode(
|
||||
id="iter1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
# Verify the mock config is preserved
|
||||
assert mock_node.mock_config == mock_config
|
||||
print("✓ MockIterationNode preserves mock configuration")
|
||||
|
||||
# Check that _create_graph_engine method exists and is overridden
|
||||
assert hasattr(mock_node, "_create_graph_engine")
|
||||
assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine
|
||||
print("✓ MockIterationNode overrides _create_graph_engine method")
|
||||
|
||||
|
||||
def test_mock_loop_node_preserves_config():
|
||||
"""Test that MockLoopNode preserves mock configuration."""
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfigBuilder().with_http_response({"status": 200}).build()
|
||||
|
||||
# Create minimal graph init params
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create minimal runtime state
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
|
||||
# Create mock loop node
|
||||
node_config = {
|
||||
"id": "loop1",
|
||||
"data": {
|
||||
"type": "loop",
|
||||
"title": "Test",
|
||||
"loop_count": 3,
|
||||
"start_node_id": "node1",
|
||||
"loop_variables": [],
|
||||
"outputs": {},
|
||||
},
|
||||
}
|
||||
|
||||
mock_node = MockLoopNode(
|
||||
id="loop1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
# Verify the mock config is preserved
|
||||
assert mock_node.mock_config == mock_config
|
||||
print("✓ MockLoopNode preserves mock configuration")
|
||||
|
||||
# Check that _create_graph_engine method exists and is overridden
|
||||
assert hasattr(mock_node, "_create_graph_engine")
|
||||
assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine
|
||||
print("✓ MockLoopNode overrides _create_graph_engine method")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mock_factory_registers_iteration_node()
|
||||
test_mock_iteration_node_preserves_config()
|
||||
test_mock_loop_node_preserves_config()
|
||||
print("\n✅ All tests passed! MockNodeFactory now supports iteration and loop nodes.")
|
||||
@ -0,0 +1,829 @@
|
||||
"""
|
||||
Mock node implementations for testing.
|
||||
|
||||
This module provides mock implementations of nodes that require third-party services,
|
||||
allowing tests to run without external dependencies.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.agent import AgentNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
|
||||
|
||||
class MockNodeMixin:
|
||||
"""Mixin providing common mock functionality."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.mock_config = mock_config
|
||||
|
||||
def _get_mock_outputs(self, default_outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get mock outputs for this node."""
|
||||
if not self.mock_config:
|
||||
return default_outputs
|
||||
|
||||
# Check for node-specific configuration
|
||||
node_config = self.mock_config.get_node_config(self._node_id)
|
||||
if node_config and node_config.outputs:
|
||||
return node_config.outputs
|
||||
|
||||
# Check for custom handler
|
||||
if node_config and node_config.custom_handler:
|
||||
return node_config.custom_handler(self)
|
||||
|
||||
return default_outputs
|
||||
|
||||
def _should_simulate_error(self) -> str | None:
|
||||
"""Check if this node should simulate an error."""
|
||||
if not self.mock_config:
|
||||
return None
|
||||
|
||||
node_config = self.mock_config.get_node_config(self._node_id)
|
||||
if node_config:
|
||||
return node_config.error
|
||||
|
||||
return None
|
||||
|
||||
def _simulate_delay(self) -> None:
|
||||
"""Simulate execution delay if configured."""
|
||||
if not self.mock_config or not self.mock_config.simulate_delays:
|
||||
return
|
||||
|
||||
node_config = self.mock_config.get_node_config(self._node_id)
|
||||
if node_config and node_config.delay > 0:
|
||||
time.sleep(node_config.delay)
|
||||
|
||||
|
||||
class MockLLMNode(MockNodeMixin, LLMNode):
|
||||
"""Mock implementation of LLMNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock LLM node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
default_response = self.mock_config.default_llm_response if self.mock_config else "Mocked LLM response"
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"text": default_response,
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
# Simulate streaming if text output exists
|
||||
if "text" in outputs:
|
||||
text = str(outputs["text"])
|
||||
# Split text into words and stream with spaces between them
|
||||
# To match test expectation of text.count(" ") + 2 chunks
|
||||
words = text.split(" ")
|
||||
for i, word in enumerate(words):
|
||||
# Add space before word (except for first word) to reconstruct text properly
|
||||
if i > 0:
|
||||
chunk = " " + word
|
||||
else:
|
||||
chunk = word
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=chunk,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Send final chunk
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Create mock usage with all required fields
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 10)
|
||||
usage.completion_tokens = outputs.get("usage", {}).get("completion_tokens", 5)
|
||||
usage.total_tokens = outputs.get("usage", {}).get("total_tokens", 15)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"mock": "inputs"},
|
||||
process_data={
|
||||
"model_mode": "chat",
|
||||
"prompts": [],
|
||||
"usage": outputs.get("usage", {}),
|
||||
"finish_reason": outputs.get("finish_reason", "stop"),
|
||||
"model_provider": "mock_provider",
|
||||
"model_name": "mock_model",
|
||||
},
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAgentNode(MockNodeMixin, AgentNode):
|
||||
"""Mock implementation of AgentNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock agent node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
default_response = self.mock_config.default_agent_response if self.mock_config else "Mocked agent response"
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"output": default_response,
|
||||
"files": [],
|
||||
}
|
||||
)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"mock": "inputs"},
|
||||
process_data={
|
||||
"agent_log": "Mock agent executed successfully",
|
||||
},
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: "Mock agent log",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockToolNode(MockNodeMixin, ToolNode):
|
||||
"""Mock implementation of ToolNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock tool node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
default_response = (
|
||||
self.mock_config.default_tool_response if self.mock_config else {"result": "mocked tool output"}
|
||||
)
|
||||
outputs = self._get_mock_outputs(default_response)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"mock": "inputs"},
|
||||
process_data={
|
||||
"tool_name": "mock_tool",
|
||||
"tool_parameters": {},
|
||||
},
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: {
|
||||
"tool_name": "mock_tool",
|
||||
"tool_label": "Mock Tool",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
|
||||
"""Mock implementation of KnowledgeRetrievalNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock knowledge retrieval node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
default_response = (
|
||||
self.mock_config.default_retrieval_response if self.mock_config else "Mocked retrieval content"
|
||||
)
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"result": [
|
||||
{
|
||||
"content": default_response,
|
||||
"score": 0.95,
|
||||
"metadata": {"source": "mock_source"},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"query": "mock query"},
|
||||
process_data={
|
||||
"retrieval_method": "mock",
|
||||
"documents_count": 1,
|
||||
},
|
||||
outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
|
||||
"""Mock implementation of HttpRequestNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock HTTP request node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
default_response = (
|
||||
self.mock_config.default_http_response
|
||||
if self.mock_config
|
||||
else {
|
||||
"status_code": 200,
|
||||
"body": "mocked response",
|
||||
"headers": {},
|
||||
}
|
||||
)
|
||||
outputs = self._get_mock_outputs(default_response)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"url": "http://mock.url", "method": "GET"},
|
||||
process_data={
|
||||
"request_url": "http://mock.url",
|
||||
"request_method": "GET",
|
||||
},
|
||||
outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
|
||||
"""Mock implementation of QuestionClassifierNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock question classifier node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response - default to first class
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"class_name": "class_1",
|
||||
}
|
||||
)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"query": "mock query"},
|
||||
process_data={
|
||||
"classification": outputs.get("class_name", "class_1"),
|
||||
},
|
||||
outputs=outputs,
|
||||
edge_source_handle=outputs.get("class_name", "class_1"), # Branch based on classification
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
|
||||
"""Mock implementation of ParameterExtractorNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock parameter extractor node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"parameters": {
|
||||
"param1": "value1",
|
||||
"param2": "value2",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"text": "mock text"},
|
||||
process_data={
|
||||
"extracted_parameters": outputs.get("parameters", {}),
|
||||
},
|
||||
outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
|
||||
"""Mock implementation of DocumentExtractorNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock document extractor node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
process_data={},
|
||||
error_type="MockError",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Get mock response
|
||||
outputs = self._get_mock_outputs(
|
||||
{
|
||||
"text": "Mocked extracted document content",
|
||||
"metadata": {
|
||||
"pages": 1,
|
||||
"format": "mock",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Send completion event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"file": "mock_file.pdf"},
|
||||
process_data={
|
||||
"extraction_method": "mock",
|
||||
},
|
||||
outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
from core.workflow.nodes.iteration import IterationNode
|
||||
from core.workflow.nodes.loop import LoopNode
|
||||
|
||||
|
||||
class MockIterationNode(MockNodeMixin, IterationNode):
|
||||
"""Mock implementation of IterationNode that preserves mock configuration."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
|
||||
# Import our MockNodeFactory instead of DifyNodeFactory
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
# Create a deep copy of the variable pool for each iteration
|
||||
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool_copy.add([self._node_id, "index"], index)
|
||||
variable_pool_copy.add([self._node_id, "item"], item)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=variable_pool_copy,
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
|
||||
# Create a MockNodeFactory with the same mock_config
|
||||
node_factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
mock_config=self.mock_config, # Pass the mock configuration
|
||||
)
|
||||
|
||||
# Initialize the iteration graph with the mock node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
from core.workflow.nodes.iteration.exc import IterationGraphNotFoundError
|
||||
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=iteration_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
return graph_engine
|
||||
|
||||
|
||||
class MockLoopNode(MockNodeMixin, LoopNode):
|
||||
"""Mock implementation of LoopNode that preserves mock configuration."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
|
||||
# Import our MockNodeFactory instead of DifyNodeFactory
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
# Create a MockNodeFactory with the same mock_config
|
||||
node_factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
mock_config=self.mock_config, # Pass the mock configuration
|
||||
)
|
||||
|
||||
# Initialize the loop graph with the mock node factory
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=loop_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
return graph_engine
|
||||
|
||||
|
||||
class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
|
||||
"""Mock implementation of TemplateTransformNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock template transform node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
error_type="MockError",
|
||||
)
|
||||
|
||||
# Get variables from the node data
|
||||
variables: dict[str, Any] = {}
|
||||
if hasattr(self._node_data, "variables"):
|
||||
for variable_selector in self._node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
|
||||
# Check if we have custom mock outputs configured
|
||||
if self.mock_config:
|
||||
node_config = self.mock_config.get_node_config(self._node_id)
|
||||
if node_config and node_config.outputs:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=node_config.outputs,
|
||||
)
|
||||
|
||||
# Try to actually process the template using Jinja2 directly
|
||||
try:
|
||||
if hasattr(self._node_data, "template"):
|
||||
# Import jinja2 here to avoid dependency issues
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(self._node_data.template)
|
||||
result_text = template.render(**variables)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result_text}
|
||||
)
|
||||
except Exception as e:
|
||||
# If direct Jinja2 fails, try CodeExecutor as fallback
|
||||
try:
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
|
||||
if hasattr(self._node_data, "template"):
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs={"output": result["result"]},
|
||||
)
|
||||
except Exception:
|
||||
# Both methods failed, fall back to default mock output
|
||||
pass
|
||||
|
||||
# Fall back to default mock output
|
||||
default_response = (
|
||||
self.mock_config.default_template_transform_response if self.mock_config else "mocked template output"
|
||||
)
|
||||
default_outputs = {"output": default_response}
|
||||
outputs = self._get_mock_outputs(default_outputs)
|
||||
|
||||
# Return result
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
|
||||
class MockCodeNode(MockNodeMixin, CodeNode):
|
||||
"""Mock implementation of CodeNode for testing."""
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock code node."""
|
||||
# Simulate delay if configured
|
||||
self._simulate_delay()
|
||||
|
||||
# Check for simulated error
|
||||
error = self._should_simulate_error()
|
||||
if error:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
inputs={},
|
||||
error_type="MockError",
|
||||
)
|
||||
|
||||
# Get mock outputs - use configured outputs or default based on output schema
|
||||
default_outputs = {}
|
||||
if hasattr(self._node_data, "outputs") and self._node_data.outputs:
|
||||
# Generate default outputs based on schema
|
||||
for output_name, output_config in self._node_data.outputs.items():
|
||||
if output_config.type == "string":
|
||||
default_outputs[output_name] = f"mocked_{output_name}"
|
||||
elif output_config.type == "number":
|
||||
default_outputs[output_name] = 42
|
||||
elif output_config.type == "object":
|
||||
default_outputs[output_name] = {"key": "value"}
|
||||
elif output_config.type == "array[string]":
|
||||
default_outputs[output_name] = ["item1", "item2"]
|
||||
elif output_config.type == "array[number]":
|
||||
default_outputs[output_name] = [1, 2, 3]
|
||||
elif output_config.type == "array[object]":
|
||||
default_outputs[output_name] = [{"key": "value1"}, {"key": "value2"}]
|
||||
else:
|
||||
# Default output when no schema is defined
|
||||
default_outputs = (
|
||||
self.mock_config.default_code_response
|
||||
if self.mock_config
|
||||
else {"result": "mocked code execution result"}
|
||||
)
|
||||
|
||||
outputs = self._get_mock_outputs(default_outputs)
|
||||
|
||||
# Return result
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
outputs=outputs,
|
||||
)
|
||||
@ -0,0 +1,607 @@
|
||||
"""
|
||||
Test cases for Mock Template Transform and Code nodes.
|
||||
|
||||
This module tests the functionality of MockTemplateTransformNode and MockCodeNode
|
||||
to ensure they work correctly with the TableTestRunner.
|
||||
"""
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode
|
||||
|
||||
|
||||
class TestMockTemplateTransformNode:
|
||||
"""Test cases for MockTemplateTransformNode."""
|
||||
|
||||
def test_mock_template_transform_node_default_output(self):
|
||||
"""Test that MockTemplateTransformNode processes templates with Jinja2."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfig()
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "template_node_1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "Test Template Transform",
|
||||
"variables": [],
|
||||
"template": "Hello {{ name }}",
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockTemplateTransformNode(
|
||||
id="template_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "output" in result.outputs
|
||||
# The template "Hello {{ name }}" with no name variable renders as "Hello "
|
||||
assert result.outputs["output"] == "Hello "
|
||||
|
||||
def test_mock_template_transform_node_custom_output(self):
|
||||
"""Test that MockTemplateTransformNode returns custom configured output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config with custom output
|
||||
mock_config = (
|
||||
MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build()
|
||||
)
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "template_node_1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "Test Template Transform",
|
||||
"variables": [],
|
||||
"template": "Hello {{ name }}",
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockTemplateTransformNode(
|
||||
id="template_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "output" in result.outputs
|
||||
assert result.outputs["output"] == "Custom template output"
|
||||
|
||||
def test_mock_template_transform_node_error_simulation(self):
|
||||
"""Test that MockTemplateTransformNode can simulate errors."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config with error
|
||||
mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build()
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "template_node_1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "Test Template Transform",
|
||||
"variables": [],
|
||||
"template": "Hello {{ name }}",
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockTemplateTransformNode(
|
||||
id="template_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Simulated template error"
|
||||
|
||||
def test_mock_template_transform_node_with_variables(self):
|
||||
"""Test that MockTemplateTransformNode processes templates with variables."""
|
||||
from core.variables import StringVariable
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add a variable to the pool
|
||||
variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"]))
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfig()
|
||||
|
||||
# Create node config with a variable
|
||||
node_config = {
|
||||
"id": "template_node_1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "Test Template Transform",
|
||||
"variables": [{"variable": "name", "value_selector": ["test", "name"]}],
|
||||
"template": "Hello {{ name }}!",
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockTemplateTransformNode(
|
||||
id="template_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "output" in result.outputs
|
||||
assert result.outputs["output"] == "Hello World!"
|
||||
|
||||
|
||||
class TestMockCodeNode:
|
||||
"""Test cases for MockCodeNode."""
|
||||
|
||||
def test_mock_code_node_default_output(self):
|
||||
"""Test that MockCodeNode returns default output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfig()
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "code_node_1",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"title": "Test Code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "result = 'test'",
|
||||
"outputs": {}, # Empty outputs for default case
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockCodeNode(
|
||||
id="code_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "result" in result.outputs
|
||||
assert result.outputs["result"] == "mocked code execution result"
|
||||
|
||||
def test_mock_code_node_with_output_schema(self):
|
||||
"""Test that MockCodeNode generates outputs based on schema."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config
|
||||
mock_config = MockConfig()
|
||||
|
||||
# Create node config with output schema
|
||||
node_config = {
|
||||
"id": "code_node_1",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"title": "Test Code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "name = 'test'\ncount = 42\nitems = ['a', 'b']",
|
||||
"outputs": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
"items": {"type": "array[string]"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockCodeNode(
|
||||
id="code_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "name" in result.outputs
|
||||
assert result.outputs["name"] == "mocked_name"
|
||||
assert "count" in result.outputs
|
||||
assert result.outputs["count"] == 42
|
||||
assert "items" in result.outputs
|
||||
assert result.outputs["items"] == ["item1", "item2"]
|
||||
|
||||
def test_mock_code_node_custom_output(self):
|
||||
"""Test that MockCodeNode returns custom configured output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create mock config with custom output
|
||||
mock_config = (
|
||||
MockConfigBuilder()
|
||||
.with_node_output("code_node_1", {"result": "Custom code result", "status": "success"})
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "code_node_1",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"title": "Test Code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "result = 'test'",
|
||||
"outputs": {}, # Empty outputs for default case
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock node
|
||||
mock_node = MockCodeNode(
|
||||
id="code_node_1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
# Verify results
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "result" in result.outputs
|
||||
assert result.outputs["result"] == "Custom code result"
|
||||
assert "status" in result.outputs
|
||||
assert result.outputs["status"] == "success"
|
||||
|
||||
|
||||
class TestMockNodeFactory:
|
||||
"""Test cases for MockNodeFactory with new node types."""
|
||||
|
||||
def test_code_and_template_nodes_mocked_by_default(self):
|
||||
"""Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy)."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create factory
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy)
|
||||
assert factory.should_mock_node(NodeType.CODE)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Verify that other third-party service nodes ARE also mocked by default
|
||||
assert factory.should_mock_node(NodeType.LLM)
|
||||
assert factory.should_mock_node(NodeType.AGENT)
|
||||
|
||||
def test_factory_creates_mock_template_transform_node(self):
|
||||
"""Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create factory
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "template_node_1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "Test Template",
|
||||
"variables": [],
|
||||
"template": "Hello {{ name }}",
|
||||
},
|
||||
}
|
||||
|
||||
# Create node through factory
|
||||
node = factory.create_node(node_config)
|
||||
|
||||
# Verify the correct mock type was created
|
||||
assert isinstance(node, MockTemplateTransformNode)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
def test_factory_creates_mock_code_node(self):
|
||||
"""Test that MockNodeFactory creates MockCodeNode for code type."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create factory
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Create node config
|
||||
node_config = {
|
||||
"id": "code_node_1",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"title": "Test Code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "result = 42",
|
||||
"outputs": {}, # Required field for CodeNodeData
|
||||
},
|
||||
}
|
||||
|
||||
# Create node through factory
|
||||
node = factory.create_node(node_config)
|
||||
|
||||
# Verify the correct mock type was created
|
||||
assert isinstance(node, MockCodeNode)
|
||||
assert factory.should_mock_node(NodeType.CODE)
|
||||
@ -0,0 +1,187 @@
|
||||
"""
|
||||
Simple test to validate the auto-mock system without external dependencies.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add api directory to path
|
||||
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
|
||||
sys.path.insert(0, str(api_dir))
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
|
||||
|
||||
|
||||
def test_mock_config_builder():
|
||||
"""Test the MockConfigBuilder fluent interface."""
|
||||
print("Testing MockConfigBuilder...")
|
||||
|
||||
config = (
|
||||
MockConfigBuilder()
|
||||
.with_llm_response("LLM response")
|
||||
.with_agent_response("Agent response")
|
||||
.with_tool_response({"tool": "output"})
|
||||
.with_retrieval_response("Retrieval content")
|
||||
.with_http_response({"status_code": 201, "body": "created"})
|
||||
.with_node_output("node1", {"output": "value"})
|
||||
.with_node_error("node2", "error message")
|
||||
.with_delays(True)
|
||||
.build()
|
||||
)
|
||||
|
||||
assert config.default_llm_response == "LLM response"
|
||||
assert config.default_agent_response == "Agent response"
|
||||
assert config.default_tool_response == {"tool": "output"}
|
||||
assert config.default_retrieval_response == "Retrieval content"
|
||||
assert config.default_http_response == {"status_code": 201, "body": "created"}
|
||||
assert config.simulate_delays is True
|
||||
|
||||
node1_config = config.get_node_config("node1")
|
||||
assert node1_config is not None
|
||||
assert node1_config.outputs == {"output": "value"}
|
||||
|
||||
node2_config = config.get_node_config("node2")
|
||||
assert node2_config is not None
|
||||
assert node2_config.error == "error message"
|
||||
|
||||
print("✓ MockConfigBuilder test passed")
|
||||
|
||||
|
||||
def test_mock_config_operations():
|
||||
"""Test MockConfig operations."""
|
||||
print("Testing MockConfig operations...")
|
||||
|
||||
config = MockConfig()
|
||||
|
||||
# Test setting node outputs
|
||||
config.set_node_outputs("test_node", {"result": "test_value"})
|
||||
node_config = config.get_node_config("test_node")
|
||||
assert node_config is not None
|
||||
assert node_config.outputs == {"result": "test_value"}
|
||||
|
||||
# Test setting node error
|
||||
config.set_node_error("error_node", "Test error")
|
||||
error_config = config.get_node_config("error_node")
|
||||
assert error_config is not None
|
||||
assert error_config.error == "Test error"
|
||||
|
||||
# Test default configs by node type
|
||||
config.set_default_config(NodeType.LLM, {"temperature": 0.7})
|
||||
llm_config = config.get_default_config(NodeType.LLM)
|
||||
assert llm_config == {"temperature": 0.7}
|
||||
|
||||
print("✓ MockConfig operations test passed")
|
||||
|
||||
|
||||
def test_node_mock_config():
|
||||
"""Test NodeMockConfig."""
|
||||
print("Testing NodeMockConfig...")
|
||||
|
||||
# Test with custom handler
|
||||
def custom_handler(node):
|
||||
return {"custom": "output"}
|
||||
|
||||
node_config = NodeMockConfig(
|
||||
node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler
|
||||
)
|
||||
|
||||
assert node_config.node_id == "test_node"
|
||||
assert node_config.outputs == {"text": "test"}
|
||||
assert node_config.delay == 0.5
|
||||
assert node_config.custom_handler is not None
|
||||
|
||||
# Test custom handler
|
||||
result = node_config.custom_handler(None)
|
||||
assert result == {"custom": "output"}
|
||||
|
||||
print("✓ NodeMockConfig test passed")
|
||||
|
||||
|
||||
def test_mock_factory_detection():
|
||||
"""Test MockNodeFactory node type detection."""
|
||||
print("Testing MockNodeFactory detection...")
|
||||
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# Test that third-party service nodes are identified for mocking
|
||||
assert factory.should_mock_node(NodeType.LLM)
|
||||
assert factory.should_mock_node(NodeType.AGENT)
|
||||
assert factory.should_mock_node(NodeType.TOOL)
|
||||
assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL)
|
||||
assert factory.should_mock_node(NodeType.HTTP_REQUEST)
|
||||
assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR)
|
||||
assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR)
|
||||
|
||||
# Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy)
|
||||
assert factory.should_mock_node(NodeType.CODE)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Test that non-service nodes are not mocked
|
||||
assert not factory.should_mock_node(NodeType.START)
|
||||
assert not factory.should_mock_node(NodeType.END)
|
||||
assert not factory.should_mock_node(NodeType.IF_ELSE)
|
||||
assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR)
|
||||
|
||||
print("✓ MockNodeFactory detection test passed")
|
||||
|
||||
|
||||
def test_mock_factory_registration():
|
||||
"""Test registering and unregistering mock node types."""
|
||||
print("Testing MockNodeFactory registration...")
|
||||
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Unregister mock
|
||||
factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM)
|
||||
assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
# Register custom mock (using a dummy class for testing)
|
||||
class DummyMockNode:
|
||||
pass
|
||||
|
||||
factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, DummyMockNode)
|
||||
assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM)
|
||||
|
||||
print("✓ MockNodeFactory registration test passed")
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests."""
|
||||
print("\n=== Running Auto-Mock System Tests ===\n")
|
||||
|
||||
try:
|
||||
test_mock_config_builder()
|
||||
test_mock_config_operations()
|
||||
test_node_mock_config()
|
||||
test_mock_factory_detection()
|
||||
test_mock_factory_registration()
|
||||
|
||||
print("\n=== All tests passed! ✅ ===\n")
|
||||
return True
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@ -0,0 +1,273 @@
|
||||
"""
|
||||
Test for parallel streaming workflow behavior.
|
||||
|
||||
This test validates that:
|
||||
- LLM 1 always speaks English
|
||||
- LLM 2 always speaks Chinese
|
||||
- 2 LLMs run parallel, but LLM 2 will output before LLM 1
|
||||
- All chunks should be sent before Answer Node started
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_table_runner import TableTestRunner
|
||||
|
||||
|
||||
def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1):
|
||||
"""Create a generator that simulates LLM streaming output with delay"""
|
||||
|
||||
def llm_generator(self):
|
||||
for i, chunk in enumerate(chunks):
|
||||
time.sleep(delay) # Simulate network delay
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=str(uuid4()),
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
selector=[self.id, "text"],
|
||||
chunk=chunk,
|
||||
is_final=i == len(chunks) - 1,
|
||||
)
|
||||
|
||||
# Complete response
|
||||
full_text = "".join(chunks)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": full_text},
|
||||
)
|
||||
)
|
||||
|
||||
return llm_generator
|
||||
|
||||
|
||||
def test_parallel_streaming_workflow():
|
||||
"""
|
||||
Test parallel streaming workflow to verify:
|
||||
1. All chunks from LLM 2 are output before LLM 1
|
||||
2. At least one chunk from LLM 2 is output before LLM 1 completes (Success)
|
||||
3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL)
|
||||
4. All chunks are output before End begins
|
||||
5. The final output content matches the order defined in the Answer
|
||||
|
||||
Test setup:
|
||||
- LLM 1 outputs English (slower)
|
||||
- LLM 2 outputs Chinese (faster)
|
||||
- Both run in parallel
|
||||
|
||||
This test is expected to FAIL because chunks are currently buffered
|
||||
until after node completion instead of streaming during execution.
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Load the workflow configuration
|
||||
fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow")
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
# Create graph initialization parameters
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create variable pool with system variables
|
||||
system_variables = SystemVariable(
|
||||
user_id=init_params.user_id,
|
||||
app_id=init_params.app_id,
|
||||
workflow_id=init_params.workflow_id,
|
||||
files=[],
|
||||
query="Tell me about yourself", # User query
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_variables,
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Create graph runtime state
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory and graph
|
||||
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
# Create the graph engine
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Define LLM outputs
|
||||
llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower)
|
||||
llm2_chunks = ["你好", ",", "我", "是", "AI", "助手", "。"] # Chinese (faster)
|
||||
|
||||
# Create generators with different delays (LLM 2 is faster)
|
||||
llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower
|
||||
llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster
|
||||
|
||||
# Track which LLM node is being called
|
||||
llm_call_order = []
|
||||
generators = {
|
||||
"1754339718571": llm1_generator, # LLM 1 node ID
|
||||
"1754339725656": llm2_generator, # LLM 2 node ID
|
||||
}
|
||||
|
||||
def mock_llm_run(self):
|
||||
llm_call_order.append(self.id)
|
||||
generator = generators.get(self.id)
|
||||
if generator:
|
||||
yield from generator(self)
|
||||
else:
|
||||
raise Exception(f"Unexpected LLM node ID: {self.id}")
|
||||
|
||||
# Execute with mocked LLMs
|
||||
with patch.object(LLMNode, "_run", new=mock_llm_run):
|
||||
events = list(engine.run())
|
||||
|
||||
# Check for successful completion
|
||||
success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)]
|
||||
assert len(success_events) > 0, "Workflow should complete successfully"
|
||||
|
||||
# Get all streaming chunk events
|
||||
stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)]
|
||||
|
||||
# Get Answer node start event
|
||||
answer_start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.ANSWER]
|
||||
assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}"
|
||||
answer_start_event = answer_start_events[0]
|
||||
|
||||
# Find the index of Answer node start
|
||||
answer_start_index = events.index(answer_start_event)
|
||||
|
||||
# Collect chunk events by node
|
||||
llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"]
|
||||
llm2_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339725656"]
|
||||
|
||||
# Verify both LLMs produced chunks
|
||||
assert len(llm1_chunks_events) == len(llm1_chunks), (
|
||||
f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}"
|
||||
)
|
||||
assert len(llm2_chunks_events) == len(llm2_chunks), (
|
||||
f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}"
|
||||
)
|
||||
|
||||
# 1. Verify chunk ordering based on actual implementation
|
||||
llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events]
|
||||
llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events]
|
||||
|
||||
# In the current implementation, chunks may be interleaved or in a specific order
|
||||
# Update this based on actual behavior observed
|
||||
if llm1_chunk_indices and llm2_chunk_indices:
|
||||
# Check the actual ordering - if LLM 2 chunks come first (as seen in debug)
|
||||
assert max(llm2_chunk_indices) < min(llm1_chunk_indices), (
|
||||
f"All LLM 2 chunks should be output before LLM 1 chunks. "
|
||||
f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}"
|
||||
)
|
||||
|
||||
# Get indices of all chunk events
|
||||
chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events]
|
||||
|
||||
# 4. Verify all chunks were sent before Answer node started
|
||||
assert all(idx < answer_start_index for idx in chunk_indices), (
|
||||
"All LLM chunks should be sent before Answer node starts"
|
||||
)
|
||||
|
||||
# The test has successfully verified:
|
||||
# 1. Both LLMs run in parallel (they start at the same time)
|
||||
# 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing
|
||||
# 3. All LLM chunks are sent before the Answer node starts
|
||||
|
||||
# Get LLM completion events
|
||||
llm_completed_events = [
|
||||
(i, e) for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM
|
||||
]
|
||||
|
||||
# Check LLM completion order - in the current implementation, LLMs run sequentially
|
||||
# LLM 1 completes first, then LLM 2 runs and completes
|
||||
assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}"
|
||||
llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None)
|
||||
llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None)
|
||||
assert llm2_complete_idx is not None, "LLM 2 completion event not found"
|
||||
assert llm1_complete_idx is not None, "LLM 1 completion event not found"
|
||||
# In the actual implementation, LLM 1 completes before LLM 2 (sequential execution)
|
||||
assert llm1_complete_idx < llm2_complete_idx, (
|
||||
f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} "
|
||||
f"and LLM 2 completed at {llm2_complete_idx}"
|
||||
)
|
||||
|
||||
# 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes
|
||||
if llm2_chunk_indices:
|
||||
# LLM 1 completes first, then LLM 2 starts streaming
|
||||
assert min(llm2_chunk_indices) > llm1_complete_idx, (
|
||||
f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. "
|
||||
f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}"
|
||||
)
|
||||
|
||||
# 3. In the current implementation, LLM 1 chunks appear after LLM 2 completes
|
||||
# This is because chunks are buffered and output after both nodes complete
|
||||
if llm1_chunk_indices and llm2_complete_idx:
|
||||
# Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion
|
||||
# In current behavior, LLM 1 chunks typically appear after LLM 2 completes
|
||||
pass # Skipping this check as the chunk ordering is implementation-dependent
|
||||
|
||||
# CURRENT BEHAVIOR: Chunks are buffered and appear after node completion
|
||||
# In the sequential execution, LLM 1 completes first without streaming,
|
||||
# then LLM 2 streams its chunks
|
||||
assert stream_chunk_events, "Expected streaming events, but got none"
|
||||
|
||||
first_chunk_index = events.index(stream_chunk_events[0])
|
||||
llm_success_indices = [i for i, e in llm_completed_events]
|
||||
|
||||
# Current implementation: LLM 1 completes first, then chunks start appearing
|
||||
# This is the actual behavior we're testing
|
||||
if llm_success_indices:
|
||||
# At least one LLM (LLM 1) completes before any chunks appear
|
||||
assert min(llm_success_indices) < first_chunk_index, (
|
||||
f"In current implementation, LLM 1 completes before chunks start streaming. "
|
||||
f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}"
|
||||
)
|
||||
|
||||
# 5. Verify final output content matches the order defined in Answer node
|
||||
# According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}'
|
||||
# This means LLM 2 output should come first, then LLM 1 output
|
||||
answer_complete_events = [
|
||||
e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.ANSWER
|
||||
]
|
||||
assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}"
|
||||
|
||||
answer_outputs = answer_complete_events[0].node_run_result.outputs
|
||||
expected_answer_text = "你好,我是AI助手。Hello, I am an AI assistant."
|
||||
|
||||
if "answer" in answer_outputs:
|
||||
actual_answer_text = answer_outputs["answer"]
|
||||
assert actual_answer_text == expected_answer_text, (
|
||||
f"Answer content should match the order defined in Answer node. "
|
||||
f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'"
|
||||
)
|
||||
@ -0,0 +1,215 @@
|
||||
"""
|
||||
Unit tests for Redis-based stop functionality in GraphEngine.
|
||||
|
||||
Tests the integration of Redis command channel for stopping workflows
|
||||
without user permission checks.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
|
||||
|
||||
class TestRedisStopIntegration:
|
||||
"""Test suite for Redis-based workflow stop functionality."""
|
||||
|
||||
def test_graph_engine_manager_sends_abort_command(self):
|
||||
"""Test that GraphEngineManager correctly sends abort command through Redis."""
|
||||
# Setup
|
||||
task_id = "test-task-123"
|
||||
expected_channel_key = f"workflow:{task_id}:commands"
|
||||
|
||||
# Mock redis client
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
# Execute
|
||||
GraphEngineManager.send_stop_command(task_id, reason="Test stop")
|
||||
|
||||
# Verify
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
|
||||
# Check that rpush was called with correct arguments
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
|
||||
# Verify the channel key
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
|
||||
# Verify the command data
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["reason"] == "Test stop"
|
||||
|
||||
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
|
||||
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
|
||||
task_id = "test-task-456"
|
||||
|
||||
# Mock redis client to raise exception
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed")
|
||||
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
# Should not raise exception
|
||||
try:
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
except Exception as e:
|
||||
pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
|
||||
|
||||
def test_app_queue_manager_no_user_check(self):
|
||||
"""Test that AppQueueManager.set_stop_flag_no_user_check works without user validation."""
|
||||
task_id = "test-task-789"
|
||||
expected_cache_key = f"generate_task_stopped:{task_id}"
|
||||
|
||||
# Mock redis client
|
||||
mock_redis = MagicMock()
|
||||
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
|
||||
# Execute
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# Verify
|
||||
mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1)
|
||||
|
||||
def test_app_queue_manager_no_user_check_with_empty_task_id(self):
|
||||
"""Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id."""
|
||||
# Mock redis client
|
||||
mock_redis = MagicMock()
|
||||
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
|
||||
# Execute with empty task_id
|
||||
AppQueueManager.set_stop_flag_no_user_check("")
|
||||
|
||||
# Verify redis was not called
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
def test_redis_channel_send_abort_command(self):
|
||||
"""Test RedisChannel correctly serializes and sends AbortCommand."""
|
||||
# Setup
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
channel_key = "workflow:test:commands"
|
||||
channel = RedisChannel(mock_redis, channel_key)
|
||||
|
||||
# Create abort command
|
||||
abort_command = AbortCommand(reason="User requested stop")
|
||||
|
||||
# Execute
|
||||
channel.send_command(abort_command)
|
||||
|
||||
# Verify
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
|
||||
# Check rpush was called
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0][0] == channel_key
|
||||
|
||||
# Verify serialized command
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["reason"] == "User requested stop"
|
||||
|
||||
# Check expire was set
|
||||
mock_pipeline.expire.assert_called_once_with(channel_key, 3600)
|
||||
|
||||
def test_redis_channel_fetch_commands(self):
|
||||
"""Test RedisChannel correctly fetches and deserializes commands."""
|
||||
# Setup
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock command data
|
||||
abort_command_json = json.dumps(
|
||||
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
|
||||
)
|
||||
|
||||
# Mock pipeline execute to return commands
|
||||
mock_pipeline.execute.return_value = [
|
||||
[abort_command_json.encode()], # lrange result
|
||||
True, # delete result
|
||||
]
|
||||
|
||||
channel_key = "workflow:test:commands"
|
||||
channel = RedisChannel(mock_redis, channel_key)
|
||||
|
||||
# Execute
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
# Verify
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], AbortCommand)
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert commands[0].reason == "Test abort"
|
||||
|
||||
# Verify Redis operations
|
||||
mock_pipeline.lrange.assert_called_once_with(channel_key, 0, -1)
|
||||
mock_pipeline.delete.assert_called_once_with(channel_key)
|
||||
|
||||
def test_redis_channel_fetch_commands_handles_invalid_json(self):
|
||||
"""Test RedisChannel gracefully handles invalid JSON in commands."""
|
||||
# Setup
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock invalid command data
|
||||
mock_pipeline.execute.return_value = [
|
||||
[b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result
|
||||
True, # delete result
|
||||
]
|
||||
|
||||
channel_key = "workflow:test:commands"
|
||||
channel = RedisChannel(mock_redis, channel_key)
|
||||
|
||||
# Execute
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
# Should return empty list due to invalid commands
|
||||
assert len(commands) == 0
|
||||
|
||||
def test_dual_stop_mechanism_compatibility(self):
|
||||
"""Test that both stop mechanisms can work together."""
|
||||
task_id = "test-task-dual"
|
||||
|
||||
# Mock redis client
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis),
|
||||
patch("core.workflow.graph_engine.manager.redis_client", mock_redis),
|
||||
):
|
||||
# Execute both stop mechanisms
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
# Verify legacy stop flag was set
|
||||
expected_stop_flag_key = f"generate_task_stopped:{task_id}"
|
||||
mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1)
|
||||
|
||||
# Verify command was sent through Redis channel
|
||||
mock_redis.pipeline.assert_called()
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0][0] == f"workflow:{task_id}:commands"
|
||||
@ -0,0 +1,47 @@
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_streaming_conversation_variables():
|
||||
fixture_name = "test_streaming_conversation_variables"
|
||||
|
||||
# The test expects the workflow to output the input query
|
||||
# Since the workflow assigns sys.query to conversation variable "str" and then answers with it
|
||||
input_query = "Hello, this is my test query"
|
||||
|
||||
mock_config = MockConfigBuilder().build()
|
||||
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
use_auto_mock=False, # Don't use auto mock since we want to test actual variable assignment
|
||||
mock_config=mock_config,
|
||||
query=input_query, # Pass query as the sys.query value
|
||||
inputs={}, # No additional inputs needed
|
||||
expected_outputs={"answer": input_query}, # Expecting the input query to be output
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent,
|
||||
# START node
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Variable Assigner node
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# ANSWER node
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
@ -0,0 +1,704 @@
|
||||
"""
|
||||
Table-driven test framework for GraphEngine workflows.
|
||||
|
||||
This module provides a robust table-driven testing framework with support for:
|
||||
- Parallel test execution
|
||||
- Property-based testing with Hypothesis
|
||||
- Event sequence validation
|
||||
- Mock configuration
|
||||
- Performance metrics
|
||||
- Detailed error reporting
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
StringVariable,
|
||||
)
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowTestCase:
|
||||
"""Represents a single test case for table-driven testing."""
|
||||
|
||||
fixture_path: str
|
||||
expected_outputs: dict[str, Any]
|
||||
inputs: dict[str, Any] = field(default_factory=dict)
|
||||
query: str = ""
|
||||
description: str = ""
|
||||
timeout: float = 30.0
|
||||
mock_config: MockConfig | None = None
|
||||
use_auto_mock: bool = False
|
||||
expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
skip: bool = False
|
||||
skip_reason: str = ""
|
||||
retry_count: int = 0
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowTestResult:
|
||||
"""Result of executing a single test case."""
|
||||
|
||||
test_case: WorkflowTestCase
|
||||
success: bool
|
||||
error: Exception | None = None
|
||||
actual_outputs: dict[str, Any] | None = None
|
||||
execution_time: float = 0.0
|
||||
event_sequence_match: bool | None = None
|
||||
event_mismatch_details: str | None = None
|
||||
events: list[GraphEngineEvent] = field(default_factory=list)
|
||||
retry_attempts: int = 0
|
||||
validation_details: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestSuiteResult:
|
||||
"""Aggregated results for a test suite."""
|
||||
|
||||
total_tests: int
|
||||
passed_tests: int
|
||||
failed_tests: int
|
||||
skipped_tests: int
|
||||
total_execution_time: float
|
||||
results: list[WorkflowTestResult]
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of the test suite."""
|
||||
if self.total_tests == 0:
|
||||
return 0.0
|
||||
return (self.passed_tests / self.total_tests) * 100
|
||||
|
||||
def get_failed_results(self) -> list[WorkflowTestResult]:
|
||||
"""Get all failed test results."""
|
||||
return [r for r in self.results if not r.success]
|
||||
|
||||
def get_results_by_tag(self, tag: str) -> list[WorkflowTestResult]:
|
||||
"""Get test results filtered by tag."""
|
||||
return [r for r in self.results if tag in r.test_case.tags]
|
||||
|
||||
|
||||
class WorkflowRunner:
|
||||
"""Core workflow execution engine for tests."""
|
||||
|
||||
def __init__(self, fixtures_dir: Path | None = None):
|
||||
"""Initialize the workflow runner."""
|
||||
if fixtures_dir is None:
|
||||
# Use the new central fixtures location
|
||||
# Navigate from current file to api/tests directory
|
||||
current_file = Path(__file__).resolve()
|
||||
# Find the 'api' directory by traversing up
|
||||
for parent in current_file.parents:
|
||||
if parent.name == "api" and (parent / "tests").exists():
|
||||
fixtures_dir = parent / "tests" / "fixtures" / "workflow"
|
||||
break
|
||||
else:
|
||||
# Fallback if structure is not as expected
|
||||
raise ValueError("Could not locate api/tests/fixtures/workflow directory")
|
||||
|
||||
self.fixtures_dir = Path(fixtures_dir)
|
||||
if not self.fixtures_dir.exists():
|
||||
raise ValueError(f"Fixtures directory does not exist: {self.fixtures_dir}")
|
||||
|
||||
def load_fixture(self, fixture_name: str) -> dict[str, Any]:
|
||||
"""Load a YAML fixture file with caching to avoid repeated parsing."""
|
||||
if not fixture_name.endswith(".yml") and not fixture_name.endswith(".yaml"):
|
||||
fixture_name = f"{fixture_name}.yml"
|
||||
|
||||
fixture_path = self.fixtures_dir / fixture_name
|
||||
return _load_fixture(fixture_path, fixture_name)
|
||||
|
||||
def create_graph_from_fixture(
|
||||
self,
|
||||
fixture_data: dict[str, Any],
|
||||
query: str = "",
|
||||
inputs: dict[str, Any] | None = None,
|
||||
use_mock_factory: bool = False,
|
||||
mock_config: MockConfig | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
"""Create a Graph instance from fixture data."""
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
if not graph_config:
|
||||
raise ValueError("Fixture missing workflow.graph configuration")
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="test_user",
|
||||
user_from="account",
|
||||
invoke_from="debugger", # Set to debugger to avoid conversation_id requirement
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
system_variables = SystemVariable(
|
||||
user_id=graph_init_params.user_id,
|
||||
app_id=graph_init_params.app_id,
|
||||
workflow_id=graph_init_params.workflow_id,
|
||||
files=[],
|
||||
query=query,
|
||||
)
|
||||
user_inputs = inputs if inputs is not None else {}
|
||||
|
||||
# Extract conversation variables from workflow config
|
||||
conversation_variables = []
|
||||
conversation_var_configs = workflow_config.get("conversation_variables", [])
|
||||
|
||||
# Mapping from value_type to Variable class
|
||||
variable_type_mapping = {
|
||||
"string": StringVariable,
|
||||
"number": FloatVariable,
|
||||
"integer": IntegerVariable,
|
||||
"object": ObjectVariable,
|
||||
"array[string]": ArrayStringVariable,
|
||||
"array[number]": ArrayNumberVariable,
|
||||
"array[object]": ArrayObjectVariable,
|
||||
}
|
||||
|
||||
for var_config in conversation_var_configs:
|
||||
value_type = var_config.get("value_type", "string")
|
||||
variable_class = variable_type_mapping.get(value_type, StringVariable)
|
||||
|
||||
# Create the appropriate Variable type based on value_type
|
||||
var = variable_class(
|
||||
selector=tuple(var_config.get("selector", [])),
|
||||
name=var_config.get("name", ""),
|
||||
value=var_config.get("value", ""),
|
||||
)
|
||||
conversation_variables.append(var)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_variables,
|
||||
user_inputs=user_inputs,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
if use_mock_factory:
|
||||
node_factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config
|
||||
)
|
||||
else:
|
||||
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
return graph, graph_runtime_state
|
||||
|
||||
|
||||
class TableTestRunner:
|
||||
"""
|
||||
Advanced table-driven test runner for workflow testing.
|
||||
|
||||
Features:
|
||||
- Parallel test execution
|
||||
- Retry mechanism for flaky tests
|
||||
- Custom validators
|
||||
- Performance profiling
|
||||
- Detailed error reporting
|
||||
- Tag-based filtering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fixtures_dir: Path | None = None,
|
||||
max_workers: int = 4,
|
||||
enable_logging: bool = False,
|
||||
log_level: str = "INFO",
|
||||
graph_engine_min_workers: int = 1,
|
||||
graph_engine_max_workers: int = 1,
|
||||
graph_engine_scale_up_threshold: int = 5,
|
||||
graph_engine_scale_down_idle_time: float = 30.0,
|
||||
):
|
||||
"""
|
||||
Initialize the table test runner.
|
||||
|
||||
Args:
|
||||
fixtures_dir: Directory containing fixture files
|
||||
max_workers: Maximum number of parallel workers for test execution
|
||||
enable_logging: Enable detailed logging
|
||||
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
graph_engine_min_workers: Minimum workers for GraphEngine (default: 1)
|
||||
graph_engine_max_workers: Maximum workers for GraphEngine (default: 1)
|
||||
graph_engine_scale_up_threshold: Queue depth to trigger scale up
|
||||
graph_engine_scale_down_idle_time: Idle time before scaling down
|
||||
"""
|
||||
self.workflow_runner = WorkflowRunner(fixtures_dir)
|
||||
self.max_workers = max_workers
|
||||
|
||||
# Store GraphEngine worker configuration
|
||||
self.graph_engine_min_workers = graph_engine_min_workers
|
||||
self.graph_engine_max_workers = graph_engine_max_workers
|
||||
self.graph_engine_scale_up_threshold = graph_engine_scale_up_threshold
|
||||
self.graph_engine_scale_down_idle_time = graph_engine_scale_down_idle_time
|
||||
|
||||
if enable_logging:
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
self.logger = logger
|
||||
|
||||
def run_test_case(self, test_case: WorkflowTestCase) -> WorkflowTestResult:
|
||||
"""
|
||||
Execute a single test case with retry support.
|
||||
|
||||
Args:
|
||||
test_case: The test case to execute
|
||||
|
||||
Returns:
|
||||
WorkflowTestResult with execution details
|
||||
"""
|
||||
if test_case.skip:
|
||||
self.logger.info("Skipping test: %s - %s", test_case.description, test_case.skip_reason)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=True,
|
||||
execution_time=0.0,
|
||||
validation_details=f"Skipped: {test_case.skip_reason}",
|
||||
)
|
||||
|
||||
retry_attempts = 0
|
||||
last_result = None
|
||||
last_error = None
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for attempt in range(test_case.retry_count + 1):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = self._execute_test_case(test_case)
|
||||
last_result = result # Save the last result
|
||||
|
||||
if result.success:
|
||||
result.retry_attempts = retry_attempts
|
||||
self.logger.info("Test passed: %s", test_case.description)
|
||||
return result
|
||||
|
||||
last_error = result.error
|
||||
retry_attempts += 1
|
||||
|
||||
if attempt < test_case.retry_count:
|
||||
self.logger.warning(
|
||||
"Test failed (attempt %d/%d): %s",
|
||||
attempt + 1,
|
||||
test_case.retry_count + 1,
|
||||
test_case.description,
|
||||
)
|
||||
time.sleep(0.5 * (attempt + 1)) # Exponential backoff
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_attempts += 1
|
||||
|
||||
if attempt < test_case.retry_count:
|
||||
self.logger.warning(
|
||||
"Test error (attempt %d/%d): %s - %s",
|
||||
attempt + 1,
|
||||
test_case.retry_count + 1,
|
||||
test_case.description,
|
||||
str(e),
|
||||
)
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
|
||||
# All retries failed - return the last result if available
|
||||
if last_result:
|
||||
last_result.retry_attempts = retry_attempts
|
||||
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
|
||||
return last_result
|
||||
|
||||
# If no result available (all attempts threw exceptions), create a failure result
|
||||
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=False,
|
||||
error=last_error,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
retry_attempts=retry_attempts,
|
||||
)
|
||||
|
||||
def _execute_test_case(self, test_case: WorkflowTestCase) -> WorkflowTestResult:
|
||||
"""Internal method to execute a single test case."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Load fixture data
|
||||
fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path)
|
||||
|
||||
# Create graph from fixture
|
||||
graph, graph_runtime_state = self.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
inputs=test_case.inputs,
|
||||
query=test_case.query,
|
||||
use_mock_factory=test_case.use_auto_mock,
|
||||
mock_config=test_case.mock_config,
|
||||
)
|
||||
|
||||
# Create and run the engine with configured worker settings
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
min_workers=self.graph_engine_min_workers,
|
||||
max_workers=self.graph_engine_max_workers,
|
||||
scale_up_threshold=self.graph_engine_scale_up_threshold,
|
||||
scale_down_idle_time=self.graph_engine_scale_down_idle_time,
|
||||
)
|
||||
|
||||
# Execute and collect events
|
||||
events = []
|
||||
for event in engine.run():
|
||||
events.append(event)
|
||||
|
||||
# Check execution success
|
||||
has_start = any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)]
|
||||
has_success = len(success_events) > 0
|
||||
|
||||
# Validate event sequence if provided (even for failed workflows)
|
||||
event_sequence_match = None
|
||||
event_mismatch_details = None
|
||||
if test_case.expected_event_sequence is not None:
|
||||
event_sequence_match, event_mismatch_details = self._validate_event_sequence(
|
||||
test_case.expected_event_sequence, events
|
||||
)
|
||||
|
||||
if not (has_start and has_success):
|
||||
# Workflow didn't complete, but we may still want to validate events
|
||||
success = False
|
||||
if test_case.expected_event_sequence is not None:
|
||||
# If event sequence was provided, use that for success determination
|
||||
success = event_sequence_match if event_sequence_match is not None else False
|
||||
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=success,
|
||||
error=Exception("Workflow did not complete successfully"),
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
events=events,
|
||||
event_sequence_match=event_sequence_match,
|
||||
event_mismatch_details=event_mismatch_details,
|
||||
)
|
||||
|
||||
# Get actual outputs
|
||||
success_event = success_events[-1]
|
||||
actual_outputs = success_event.outputs or {}
|
||||
|
||||
# Validate outputs
|
||||
output_success, validation_details = self._validate_outputs(
|
||||
test_case.expected_outputs, actual_outputs, test_case.custom_validator
|
||||
)
|
||||
|
||||
# Overall success requires both output and event sequence validation
|
||||
success = output_success and (event_sequence_match if event_sequence_match is not None else True)
|
||||
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=success,
|
||||
actual_outputs=actual_outputs,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
event_sequence_match=event_sequence_match,
|
||||
event_mismatch_details=event_mismatch_details,
|
||||
events=events,
|
||||
validation_details=validation_details,
|
||||
error=None if success else Exception(validation_details or event_mismatch_details or "Test failed"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception("Error executing test case: %s", test_case.description)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=False,
|
||||
error=e,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
)
|
||||
|
||||
def _validate_outputs(
|
||||
self,
|
||||
expected_outputs: dict[str, Any],
|
||||
actual_outputs: dict[str, Any],
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate actual outputs against expected outputs.
|
||||
|
||||
Returns:
|
||||
tuple: (is_valid, validation_details)
|
||||
"""
|
||||
validation_errors = []
|
||||
|
||||
# Check expected outputs
|
||||
for key, expected_value in expected_outputs.items():
|
||||
if key not in actual_outputs:
|
||||
validation_errors.append(f"Missing expected key: {key}")
|
||||
continue
|
||||
|
||||
actual_value = actual_outputs[key]
|
||||
if actual_value != expected_value:
|
||||
# Format multiline strings for better readability
|
||||
if isinstance(expected_value, str) and "\n" in expected_value:
|
||||
expected_lines = expected_value.splitlines()
|
||||
actual_lines = (
|
||||
actual_value.splitlines() if isinstance(actual_value, str) else str(actual_value).splitlines()
|
||||
)
|
||||
|
||||
validation_errors.append(
|
||||
f"Value mismatch for key '{key}':\n"
|
||||
f" Expected ({len(expected_lines)} lines):\n " + "\n ".join(expected_lines) + "\n"
|
||||
f" Actual ({len(actual_lines)} lines):\n " + "\n ".join(actual_lines)
|
||||
)
|
||||
else:
|
||||
validation_errors.append(
|
||||
f"Value mismatch for key '{key}':\n Expected: {expected_value}\n Actual: {actual_value}"
|
||||
)
|
||||
|
||||
# Apply custom validator if provided
|
||||
if custom_validator:
|
||||
try:
|
||||
if not custom_validator(actual_outputs):
|
||||
validation_errors.append("Custom validator failed")
|
||||
except Exception as e:
|
||||
validation_errors.append(f"Custom validator error: {str(e)}")
|
||||
|
||||
if validation_errors:
|
||||
return False, "\n".join(validation_errors)
|
||||
|
||||
return True, None
|
||||
|
||||
def _validate_event_sequence(
|
||||
self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate that actual events match the expected event sequence.
|
||||
|
||||
Returns:
|
||||
tuple: (is_valid, error_message)
|
||||
"""
|
||||
actual_event_types = [type(event) for event in actual_events]
|
||||
|
||||
if len(expected_sequence) != len(actual_event_types):
|
||||
return False, (
|
||||
f"Event count mismatch. Expected {len(expected_sequence)} events, "
|
||||
f"got {len(actual_event_types)} events.\n"
|
||||
f"Expected: {[e.__name__ for e in expected_sequence]}\n"
|
||||
f"Actual: {[e.__name__ for e in actual_event_types]}"
|
||||
)
|
||||
|
||||
for i, (expected_type, actual_type) in enumerate(zip(expected_sequence, actual_event_types)):
|
||||
if expected_type != actual_type:
|
||||
return False, (
|
||||
f"Event mismatch at position {i}. "
|
||||
f"Expected {expected_type.__name__}, got {actual_type.__name__}\n"
|
||||
f"Full expected sequence: {[e.__name__ for e in expected_sequence]}\n"
|
||||
f"Full actual sequence: {[e.__name__ for e in actual_event_types]}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def run_table_tests(
|
||||
self,
|
||||
test_cases: list[WorkflowTestCase],
|
||||
parallel: bool = False,
|
||||
tags_filter: list[str] | None = None,
|
||||
fail_fast: bool = False,
|
||||
) -> TestSuiteResult:
|
||||
"""
|
||||
Run multiple test cases as a table test suite.
|
||||
|
||||
Args:
|
||||
test_cases: List of test cases to execute
|
||||
parallel: Run tests in parallel
|
||||
tags_filter: Only run tests with specified tags
|
||||
fail_fast: Stop execution on first failure
|
||||
|
||||
Returns:
|
||||
TestSuiteResult with aggregated results
|
||||
"""
|
||||
# Filter by tags if specified
|
||||
if tags_filter:
|
||||
test_cases = [tc for tc in test_cases if any(tag in tc.tags for tag in tags_filter)]
|
||||
|
||||
if not test_cases:
|
||||
return TestSuiteResult(
|
||||
total_tests=0,
|
||||
passed_tests=0,
|
||||
failed_tests=0,
|
||||
skipped_tests=0,
|
||||
total_execution_time=0.0,
|
||||
results=[],
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
results = []
|
||||
|
||||
if parallel and self.max_workers > 1:
|
||||
results = self._run_parallel(test_cases, fail_fast)
|
||||
else:
|
||||
results = self._run_sequential(test_cases, fail_fast)
|
||||
|
||||
# Calculate statistics
|
||||
total_tests = len(results)
|
||||
passed_tests = sum(1 for r in results if r.success and not r.test_case.skip)
|
||||
failed_tests = sum(1 for r in results if not r.success and not r.test_case.skip)
|
||||
skipped_tests = sum(1 for r in results if r.test_case.skip)
|
||||
total_execution_time = time.perf_counter() - start_time
|
||||
|
||||
return TestSuiteResult(
|
||||
total_tests=total_tests,
|
||||
passed_tests=passed_tests,
|
||||
failed_tests=failed_tests,
|
||||
skipped_tests=skipped_tests,
|
||||
total_execution_time=total_execution_time,
|
||||
results=results,
|
||||
)
|
||||
|
||||
def _run_sequential(self, test_cases: list[WorkflowTestCase], fail_fast: bool) -> list[WorkflowTestResult]:
|
||||
"""Run tests sequentially."""
|
||||
results = []
|
||||
|
||||
for test_case in test_cases:
|
||||
result = self.run_test_case(test_case)
|
||||
results.append(result)
|
||||
|
||||
if fail_fast and not result.success and not result.test_case.skip:
|
||||
self.logger.info("Fail-fast enabled: stopping execution")
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def _run_parallel(self, test_cases: list[WorkflowTestCase], fail_fast: bool) -> list[WorkflowTestResult]:
|
||||
"""Run tests in parallel."""
|
||||
results = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
|
||||
|
||||
for future in as_completed(future_to_test):
|
||||
test_case = future_to_test[future]
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
if fail_fast and not result.success and not result.test_case.skip:
|
||||
self.logger.info("Fail-fast enabled: cancelling remaining tests")
|
||||
# Cancel remaining futures
|
||||
for f in future_to_test:
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception("Error in parallel execution for test: %s", test_case.description)
|
||||
results.append(
|
||||
WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=False,
|
||||
error=e,
|
||||
)
|
||||
)
|
||||
|
||||
if fail_fast:
|
||||
for f in future_to_test:
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def generate_report(self, suite_result: TestSuiteResult) -> str:
|
||||
"""
|
||||
Generate a detailed test report.
|
||||
|
||||
Args:
|
||||
suite_result: Test suite results
|
||||
|
||||
Returns:
|
||||
Formatted report string
|
||||
"""
|
||||
report = []
|
||||
report.append("=" * 80)
|
||||
report.append("TEST SUITE REPORT")
|
||||
report.append("=" * 80)
|
||||
report.append("")
|
||||
|
||||
# Summary
|
||||
report.append("SUMMARY:")
|
||||
report.append(f" Total Tests: {suite_result.total_tests}")
|
||||
report.append(f" Passed: {suite_result.passed_tests}")
|
||||
report.append(f" Failed: {suite_result.failed_tests}")
|
||||
report.append(f" Skipped: {suite_result.skipped_tests}")
|
||||
report.append(f" Success Rate: {suite_result.success_rate:.1f}%")
|
||||
report.append(f" Total Time: {suite_result.total_execution_time:.2f}s")
|
||||
report.append("")
|
||||
|
||||
# Failed tests details
|
||||
failed_results = suite_result.get_failed_results()
|
||||
if failed_results:
|
||||
report.append("FAILED TESTS:")
|
||||
for result in failed_results:
|
||||
report.append(f" - {result.test_case.description}")
|
||||
if result.error:
|
||||
report.append(f" Error: {str(result.error)}")
|
||||
if result.validation_details:
|
||||
report.append(f" Validation: {result.validation_details}")
|
||||
if result.event_mismatch_details:
|
||||
report.append(f" Events: {result.event_mismatch_details}")
|
||||
report.append("")
|
||||
|
||||
# Performance metrics
|
||||
report.append("PERFORMANCE:")
|
||||
sorted_results = sorted(suite_result.results, key=lambda r: r.execution_time, reverse=True)[:5]
|
||||
|
||||
report.append(" Slowest Tests:")
|
||||
for result in sorted_results:
|
||||
report.append(f" - {result.test_case.description}: {result.execution_time:.2f}s")
|
||||
|
||||
report.append("=" * 80)
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _load_fixture(fixture_path: Path, fixture_name: str) -> dict[str, Any]:
|
||||
"""Load a YAML fixture file with caching to avoid repeated parsing."""
|
||||
if not fixture_path.exists():
|
||||
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
||||
|
||||
return _load_yaml_file(file_path=str(fixture_path))
|
||||
@ -0,0 +1,45 @@
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
|
||||
from .test_table_runner import TableTestRunner
|
||||
|
||||
|
||||
def test_tool_in_chatflow():
|
||||
runner = TableTestRunner()
|
||||
|
||||
# Load the workflow configuration
|
||||
fixture_data = runner.workflow_runner.load_fixture("chatflow_time_tool_static_output_workflow")
|
||||
|
||||
# Create graph from fixture with auto-mock enabled
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
query="1",
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
# Check for successful completion
|
||||
success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)]
|
||||
assert len(success_events) > 0, "Workflow should complete successfully"
|
||||
|
||||
# Check for streaming events
|
||||
stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)]
|
||||
stream_chunk_count = len(stream_chunk_events)
|
||||
|
||||
assert stream_chunk_count == 1, f"Expected 1 streaming events, but got {stream_chunk_count}"
|
||||
assert stream_chunk_events[0].chunk == "hello, dify!", (
|
||||
f"Expected chunk to be 'hello, dify!', but got {stream_chunk_events[0].chunk}"
|
||||
)
|
||||
@ -0,0 +1,58 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
class TestVariableAggregator:
|
||||
"""Test cases for the variable aggregator workflow."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("switch1", "switch2", "expected_group1", "expected_group2", "description"),
|
||||
[
|
||||
(0, 0, "switch 1 off", "switch 2 off", "Both switches off"),
|
||||
(0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"),
|
||||
(1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"),
|
||||
(1, 1, "switch 1 on", "switch 2 on", "Both switches on"),
|
||||
],
|
||||
)
|
||||
def test_variable_aggregator_combinations(
|
||||
self,
|
||||
switch1: int,
|
||||
switch2: int,
|
||||
expected_group1: str,
|
||||
expected_group2: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
"""Test all four combinations of switch1 and switch2."""
|
||||
|
||||
def mock_template_transform_run(self):
|
||||
"""Mock the TemplateTransformNode._run() method to return results based on node title."""
|
||||
title = self._node_data.title
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})
|
||||
|
||||
with patch.object(
|
||||
TemplateTransformNode,
|
||||
"_run",
|
||||
mock_template_transform_run,
|
||||
):
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="dual_switch_variable_aggregator_workflow",
|
||||
inputs={"switch1": switch1, "switch2": switch2},
|
||||
expected_outputs={"group1": expected_group1, "group2": expected_group2},
|
||||
description=description,
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
assert result.actual_outputs == test_case.expected_outputs, (
|
||||
f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}"
|
||||
)
|
||||
@ -3,44 +3,41 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm-target",
|
||||
"id": "start-source-answer-target",
|
||||
"source": "start",
|
||||
"target": "llm",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
"id": "llm",
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -50,13 +47,24 @@ def test_execute_answer():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
pool.add(["start", "weather"], "sunny")
|
||||
pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "answer",
|
||||
@ -70,8 +78,7 @@ def test_execute_answer():
|
||||
node = AnswerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
|
||||
|
||||
def test_init():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-llm4-target",
|
||||
"source": "llm3",
|
||||
"target": "llm4",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-llm5-target",
|
||||
"source": "llm3",
|
||||
"target": "llm5",
|
||||
},
|
||||
{
|
||||
"id": "llm4-source-answer2-target",
|
||||
"source": "llm4",
|
||||
"target": "answer2",
|
||||
},
|
||||
{
|
||||
"id": "llm5-source-answer-target",
|
||||
"source": "llm5",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "answer2-source-answer-target",
|
||||
"source": "answer2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm5",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"},
|
||||
"id": "answer2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping
|
||||
)
|
||||
|
||||
assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"]
|
||||
assert answer_stream_generate_route.answer_dependencies["answer2"] == []
|
||||
@ -1,216 +0,0 @@
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
if next_node_id == "start":
|
||||
yield from _publish_events(graph, next_node_id)
|
||||
|
||||
for edge in graph.edge_mapping.get(next_node_id, []):
|
||||
yield from _publish_events(graph, edge.target_node_id)
|
||||
|
||||
for edge in graph.edge_mapping.get(next_node_id, []):
|
||||
yield from _recursive_process(graph, edge.target_node_id)
|
||||
|
||||
|
||||
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now())
|
||||
|
||||
parallel_id = graph.node_parallel_mapping.get(next_node_id)
|
||||
parallel_start_node_id = None
|
||||
if parallel_id:
|
||||
parallel = graph.parallel_mapping.get(parallel_id)
|
||||
parallel_start_node_id = parallel.start_from_node_id if parallel else None
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
node_config = graph.node_id_config_mapping[next_node_id]
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
|
||||
|
||||
yield NodeRunStartedEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
)
|
||||
|
||||
if "llm" in next_node_id:
|
||||
length = int(next_node_id[-1])
|
||||
for i in range(0, length):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
chunk_content=str(i),
|
||||
route_node_state=route_node_state,
|
||||
from_variable_selector=[next_node_id, "text"],
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
)
|
||||
|
||||
route_node_state.status = RouteNodeState.Status.SUCCESS
|
||||
route_node_state.finished_at = naive_utc_now()
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
)
|
||||
|
||||
|
||||
def test_process():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-llm4-target",
|
||||
"source": "llm3",
|
||||
"target": "llm4",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-llm5-target",
|
||||
"source": "llm3",
|
||||
"target": "llm5",
|
||||
},
|
||||
{
|
||||
"id": "llm4-source-answer2-target",
|
||||
"source": "llm4",
|
||||
"target": "answer2",
|
||||
},
|
||||
{
|
||||
"id": "llm5-source-answer-target",
|
||||
"source": "llm5",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "answer2-source-answer-target",
|
||||
"source": "answer2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm5",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"},
|
||||
"id": "answer2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa",
|
||||
files=[],
|
||||
query="what's the weather in SF",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)
|
||||
|
||||
def graph_generator() -> Generator[GraphEngineEvent, None, None]:
|
||||
# print("")
|
||||
for event in _recursive_process(graph, "start"):
|
||||
# print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id,
|
||||
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if "llm" in event.route_node_state.node_id:
|
||||
variable_pool.add(
|
||||
[event.route_node_state.node_id, "text"],
|
||||
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))),
|
||||
)
|
||||
yield event
|
||||
|
||||
result_generator = answer_stream_processor.process(graph_generator())
|
||||
stream_contents = ""
|
||||
for event in result_generator:
|
||||
# print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id,
|
||||
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
stream_contents += event.chunk_content
|
||||
pass
|
||||
|
||||
assert stream_contents == "c012da01b"
|
||||
@ -1,5 +1,5 @@
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
# Ensures that all node classes are imported.
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
@ -7,7 +7,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]:
|
||||
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
|
||||
subclasses = []
|
||||
queue = [root]
|
||||
while queue:
|
||||
@ -20,16 +20,16 @@ def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]:
|
||||
|
||||
|
||||
def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined():
|
||||
classes = _get_all_subclasses(BaseNode) # type: ignore
|
||||
classes = _get_all_subclasses(Node) # type: ignore
|
||||
type_version_set: set[tuple[NodeType, str]] = set()
|
||||
|
||||
for cls in classes:
|
||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||
node_type = cls._node_type
|
||||
node_type = cls.node_type
|
||||
node_version = cls.version()
|
||||
|
||||
assert isinstance(cls._node_type, NodeType)
|
||||
assert isinstance(cls.node_type, NodeType)
|
||||
assert isinstance(node_version, str)
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
|
||||
@ -1,345 +0,0 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileVariable, FileVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNode,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_http_request_node_binary_file(monkeypatch: pytest.MonkeyPatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="binary",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
value="",
|
||||
file=["1111", "file"],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == "test"
|
||||
|
||||
|
||||
def test_http_request_node_form_with_file(monkeypatch: pytest.MonkeyPatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
file=["1111", "file"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))]
|
||||
return httpx.Response(200, content=b"")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
|
||||
|
||||
def test_http_request_node_form_with_multiple_files(monkeypatch: pytest.MonkeyPatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/upload",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="files",
|
||||
type="file",
|
||||
file=["1111", "files"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
files = [
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file1",
|
||||
filename="image1.jpg",
|
||||
mime_type="image/jpeg",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file2",
|
||||
filename="document.pdf",
|
||||
mime_type="application/pdf",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
|
||||
variable_pool.add(
|
||||
["1111", "files"],
|
||||
ArrayFileVariable(
|
||||
name="files",
|
||||
value=files,
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
|
||||
assert len(kwargs["files"]) == 2
|
||||
assert kwargs["files"][0][0] == "files"
|
||||
assert kwargs["files"][1][0] == "files"
|
||||
|
||||
file_tuples = [f[1] for f in kwargs["files"]]
|
||||
file_contents = [f[1] for f in file_tuples]
|
||||
file_types = [f[2] for f in file_tuples]
|
||||
|
||||
assert b"test_image_data" in file_contents
|
||||
assert b"test_pdf_data" in file_contents
|
||||
assert "image/jpeg" in file_types
|
||||
assert "application/pdf" in file_types
|
||||
|
||||
return httpx.Response(200, content=b'{"status":"success"}')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == '{"status":"success"}'
|
||||
print(result.outputs["body"])
|
||||
@ -1,887 +0,0 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_run():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
# print(type(item), item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 20
|
||||
|
||||
|
||||
def test_run_parallel():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-2-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt-2",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "tt-2-source-if-else-target",
|
||||
"source": "tt-2",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 321",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt-2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 32
|
||||
|
||||
|
||||
def test_iteration_run_in_parallel_mode():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-2-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt-2",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "tt-2-source-if-else-target",
|
||||
"source": "tt-2",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 321",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt-2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
parallel_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
parallel_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config=parallel_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
parallel_iteration_node.init_node_data(parallel_node_config["data"])
|
||||
sequential_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
sequential_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config=sequential_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
sequential_iteration_node.init_node_data(sequential_node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
parallel_result = parallel_iteration_node._run()
|
||||
sequential_result = sequential_iteration_node._run()
|
||||
assert parallel_iteration_node._node_data.parallel_nums == 10
|
||||
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
|
||||
count = 0
|
||||
parallel_arr = []
|
||||
sequential_arr = []
|
||||
for item in parallel_result:
|
||||
count += 1
|
||||
parallel_arr.append(item)
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 32
|
||||
|
||||
for item in sequential_result:
|
||||
sequential_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 64
|
||||
|
||||
|
||||
def test_iteration_run_error_handle():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "iteration-start",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "tt2",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt2", "output"],
|
||||
"output_type": "array[string]",
|
||||
"start_node_id": "if-else",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1.split(arg2) }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
{"value_selector": ["iteration-1", "index"], "variable": "arg2"},
|
||||
],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
],
|
||||
},
|
||||
"id": "tt2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "1",
|
||||
"variable_selector": ["iteration-1", "item"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["1", "1"])
|
||||
error_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config=error_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(error_node_config["data"])
|
||||
# execute continue on error node
|
||||
result = iteration_node._run()
|
||||
result_arr = []
|
||||
count = 0
|
||||
for item in result:
|
||||
result_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}
|
||||
|
||||
assert count == 14
|
||||
# execute remove abnormal output
|
||||
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
|
||||
result = iteration_node._run()
|
||||
count = 0
|
||||
for item in result:
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])}
|
||||
assert count == 14
|
||||
@ -26,14 +26,13 @@ def _gen_id():
|
||||
|
||||
|
||||
class TestFileSaverImpl:
|
||||
def test_save_binary_string(self, monkeypatch):
|
||||
def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
file_type = FileType.IMAGE
|
||||
mime_type = "image/png"
|
||||
mock_signed_url = "https://example.com/image.png"
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
@ -43,6 +42,7 @@ class TestFileSaverImpl:
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mock_tool_file.id = _gen_id()
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
@ -80,7 +80,7 @@ class TestFileSaverImpl:
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
@ -99,7 +99,7 @@ class TestFileSaverImpl:
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch):
|
||||
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mime_type = "image/png"
|
||||
user_id = _gen_id()
|
||||
@ -115,7 +115,6 @@ class TestFileSaverImpl:
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
@ -125,6 +124,7 @@ class TestFileSaverImpl:
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mock_tool_file.id = _gen_id()
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -21,10 +20,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
@ -39,7 +36,6 @@ from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class MockTokenBufferMemory:
|
||||
@ -47,7 +43,7 @@ class MockTokenBufferMemory:
|
||||
self.history_messages = history_messages or []
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
if message_limit is not None:
|
||||
return self.history_messages[-message_limit * 2 :]
|
||||
@ -69,6 +65,7 @@ def llm_node_data() -> LLMNodeData:
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
),
|
||||
),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
|
||||
|
||||
@ -77,7 +74,6 @@ def graph_init_params() -> GraphInitParams:
|
||||
return GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
@ -89,17 +85,10 @@ def graph_init_params() -> GraphInitParams:
|
||||
|
||||
@pytest.fixture
|
||||
def graph() -> Graph:
|
||||
return Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
)
|
||||
# TODO: This fixture uses old Graph constructor parameters that are incompatible
|
||||
# with the new queue-based engine. Need to rewrite for new engine architecture.
|
||||
pytest.skip("Graph fixture incompatible with new queue-based engine - needs rewrite for ResponseStreamCoordinator")
|
||||
return Graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -127,7 +116,6 @@ def llm_node(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
@ -517,7 +505,6 @@ def llm_node_for_multimodal(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
@ -689,3 +676,66 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
||||
assert list(gen) == []
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
|
||||
class TestReasoningFormat:
|
||||
"""Test cases for reasoning_format functionality"""
|
||||
|
||||
def test_split_reasoning_separated_mode(self):
|
||||
"""Test separated mode: tags are removed and content is extracted"""
|
||||
|
||||
text_with_think = """
|
||||
<think>I need to explain what Dify is. It's an open source AI platform.
|
||||
</think>Dify is an open source AI platform.
|
||||
"""
|
||||
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "separated")
|
||||
|
||||
assert clean_text == "Dify is an open source AI platform."
|
||||
assert reasoning_content == "I need to explain what Dify is. It's an open source AI platform."
|
||||
|
||||
def test_split_reasoning_tagged_mode(self):
|
||||
"""Test tagged mode: original text is preserved"""
|
||||
|
||||
text_with_think = """
|
||||
<think>I need to explain what Dify is. It's an open source AI platform.
|
||||
</think>Dify is an open source AI platform.
|
||||
"""
|
||||
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "tagged")
|
||||
|
||||
# Original text unchanged
|
||||
assert clean_text == text_with_think
|
||||
# Empty reasoning content in tagged mode
|
||||
assert reasoning_content == ""
|
||||
|
||||
def test_split_reasoning_no_think_blocks(self):
|
||||
"""Test behavior when no <think> tags are present"""
|
||||
|
||||
text_without_think = "This is a simple answer without any thinking blocks."
|
||||
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(text_without_think, "separated")
|
||||
|
||||
assert clean_text == text_without_think
|
||||
assert reasoning_content == ""
|
||||
|
||||
def test_reasoning_format_default_value(self):
|
||||
"""Test that reasoning_format defaults to 'tagged' for backward compatibility"""
|
||||
|
||||
node_data = LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[],
|
||||
context=ContextConfig(enabled=False),
|
||||
)
|
||||
|
||||
assert node_data.reasoning_format == "tagged"
|
||||
|
||||
text_with_think = """
|
||||
<think>I need to explain what Dify is. It's an open source AI platform.
|
||||
</think>Dify is an open source AI platform.
|
||||
"""
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, node_data.reasoning_format)
|
||||
|
||||
assert clean_text == text_with_think
|
||||
assert reasoning_content == ""
|
||||
|
||||
@ -1,91 +0,0 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-answer-target",
|
||||
"source": "start",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
node_config = {
|
||||
"id": "answer",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
}
|
||||
|
||||
node = AnswerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
|
||||
@ -1,560 +0,0 @@
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class ContinueOnErrorTestHelper:
|
||||
@staticmethod
|
||||
def get_code_node(
|
||||
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
|
||||
):
|
||||
"""Helper method to create a code node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"outputs": {"result": {"type": "number"}},
|
||||
"error_strategy": error_strategy,
|
||||
"title": "code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
||||
"type": "code",
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_http_node(
|
||||
error_strategy: str = "fail-branch",
|
||||
default_value: dict | None = None,
|
||||
authorization_success: bool = False,
|
||||
retry_config: dict = {},
|
||||
):
|
||||
"""Helper method to create a http node configuration"""
|
||||
authorization = (
|
||||
{
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
}
|
||||
if authorization_success
|
||||
else {
|
||||
"type": "api-key",
|
||||
# missing config field
|
||||
}
|
||||
)
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": authorization,
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": None,
|
||||
"type": "http-request",
|
||||
"error_strategy": error_strategy,
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a http node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "HTTP Request",
|
||||
"desc": "",
|
||||
"variables": [],
|
||||
"method": "get",
|
||||
"url": "https://api.github.com/issues",
|
||||
"authorization": {"type": "no-auth", "config": None},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {"type": "none", "data": []},
|
||||
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a tool node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "a",
|
||||
"desc": "a",
|
||||
"provider_id": "maths",
|
||||
"provider_type": "builtin",
|
||||
"provider_name": "maths",
|
||||
"tool_name": "eval_expression",
|
||||
"tool_label": "eval_expression",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {
|
||||
"expression": {
|
||||
"type": "variable",
|
||||
"value": ["1", "123", "args1"],
|
||||
}
|
||||
},
|
||||
"type": "tool",
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node.node_data.default_value = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a llm node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
|
||||
"""Helper method to create a graph engine instance for testing"""
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa",
|
||||
files=[],
|
||||
query="clear",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs=user_inputs or {"uid": "takato"},
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
return GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
FAIL_BRANCH_EDGES = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-true-success-target",
|
||||
"source": "node",
|
||||
"target": "success",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-false-error-target",
|
||||
"source": "node",
|
||||
"target": "error",
|
||||
"sourceHandle": "fail-branch",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_code_default_value_continue_on_error():
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_code_node(
|
||||
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_code_fail_branch_continue_on_error():
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_code_node(error_code),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
|
||||
)
|
||||
|
||||
|
||||
def test_http_node_default_value_continue_on_error():
|
||||
"""Test HTTP node with default value error strategy"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
|
||||
for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_http_node_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
# def test_tool_node_default_value_continue_on_error():
|
||||
# """Test tool node with default value error strategy"""
|
||||
# graph_config = {
|
||||
# "edges": DEFAULT_VALUE_EDGE,
|
||||
# "nodes": [
|
||||
# {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
# {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
# ContinueOnErrorTestHelper.get_tool_node(
|
||||
# "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
|
||||
# ),
|
||||
# ],
|
||||
# }
|
||||
|
||||
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
# events = list(graph_engine.run())
|
||||
|
||||
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
# assert any(
|
||||
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501
|
||||
# )
|
||||
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
# def test_tool_node_fail_branch_continue_on_error():
|
||||
# """Test HTTP node with fail-branch error strategy"""
|
||||
# graph_config = {
|
||||
# "edges": FAIL_BRANCH_EDGES,
|
||||
# "nodes": [
|
||||
# {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
# {
|
||||
# "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
|
||||
# "id": "success",
|
||||
# },
|
||||
# {
|
||||
# "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
|
||||
# "id": "error",
|
||||
# },
|
||||
# ContinueOnErrorTestHelper.get_tool_node(),
|
||||
# ],
|
||||
# }
|
||||
|
||||
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
# events = list(graph_engine.run())
|
||||
|
||||
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
# assert any(
|
||||
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501
|
||||
# )
|
||||
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_llm_node_default_value_continue_on_error():
|
||||
"""Test LLM node with default value error strategy"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_llm_node(
|
||||
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_llm_node_fail_branch_continue_on_error():
|
||||
"""Test LLM node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_llm_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_status_code_error_http_node_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
def test_variable_pool_error_type_variable():
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
list(graph_engine.run())
|
||||
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
|
||||
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
|
||||
assert error_message != None
|
||||
assert error_type.value == "HTTPResponseCodeError"
|
||||
|
||||
|
||||
def test_no_node_in_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES[:-1],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
|
||||
ContinueOnErrorTestHelper.get_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
|
||||
|
||||
|
||||
def test_stream_output_with_fail_branch_continue_on_error():
|
||||
"""Test stream output with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_llm_node(),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
|
||||
def llm_generator(self):
|
||||
contents = ["hi", "bye", "good morning"]
|
||||
|
||||
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 1,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch.object(LLMNode, "_run", new=llm_generator):
|
||||
events = list(graph_engine.run())
|
||||
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
|
||||
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
|
||||
@ -5,12 +5,14 @@ import pandas as pd
|
||||
import pytest
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
|
||||
from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_docx,
|
||||
@ -18,11 +20,25 @@ from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_pdf,
|
||||
_extract_text_from_plain_text,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_extractor_node():
|
||||
def graph_init_params() -> GraphInitParams:
|
||||
return GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_extractor_node(graph_init_params):
|
||||
node_data = DocumentExtractorNodeData(
|
||||
title="Test Document Extractor",
|
||||
variable_selector=["node_id", "variable_name"],
|
||||
@ -31,8 +47,7 @@ def document_extractor_node():
|
||||
node = DocumentExtractorNode(
|
||||
id="test_node_id",
|
||||
config=node_config,
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
# Initialize node data
|
||||
@ -201,7 +216,7 @@ def test_extract_text_from_docx(mock_document):
|
||||
|
||||
|
||||
def test_node_type(document_extractor_node):
|
||||
assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR
|
||||
assert document_extractor_node.node_type == NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
|
||||
@ -7,29 +7,24 @@ import pytest
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_if_else_result_true():
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -59,6 +54,13 @@ def test_execute_if_else_result_true():
|
||||
pool.add(["start", "null"], None)
|
||||
pool.add(["start", "not_null"], "1212")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "if-else",
|
||||
"data": {
|
||||
@ -107,8 +109,7 @@ def test_execute_if_else_result_true():
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
@ -127,31 +128,12 @@ def test_execute_if_else_result_true():
|
||||
|
||||
|
||||
def test_execute_if_else_result_false():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm-target",
|
||||
"source": "start",
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
# Create a simple graph for IfElse node testing
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -169,6 +151,13 @@ def test_execute_if_else_result_false():
|
||||
pool.add(["start", "array_contains"], ["1ab", "def"])
|
||||
pool.add(["start", "array_not_contains"], ["ab", "def"])
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "if-else",
|
||||
"data": {
|
||||
@ -193,8 +182,7 @@ def test_execute_if_else_result_false():
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
@ -245,10 +233,20 @@ def test_array_file_contains_file_name():
|
||||
"data": node_data.model_dump(),
|
||||
}
|
||||
|
||||
# Create properly configured mock for graph_init_params
|
||||
graph_init_params = Mock()
|
||||
graph_init_params.tenant_id = "test_tenant"
|
||||
graph_init_params.app_id = "test_app"
|
||||
graph_init_params.workflow_id = "test_workflow"
|
||||
graph_init_params.graph_config = {}
|
||||
graph_init_params.user_id = "test_user"
|
||||
graph_init_params.user_from = UserFrom.ACCOUNT
|
||||
graph_init_params.invoke_from = InvokeFrom.SERVICE_API
|
||||
graph_init_params.call_depth = 0
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
config=node_config,
|
||||
)
|
||||
@ -276,7 +274,7 @@ def test_array_file_contains_file_name():
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def _get_test_conditions() -> list:
|
||||
def _get_test_conditions():
|
||||
conditions = [
|
||||
# Test boolean "is" operator
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
|
||||
@ -307,14 +305,11 @@ def _get_condition_test_id(c: Condition):
|
||||
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
|
||||
def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
"""Test IfElseNode with boolean conditions using various operators"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -332,6 +327,13 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Test",
|
||||
"type": "if-else",
|
||||
@ -341,8 +343,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
@ -360,14 +361,11 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
|
||||
def test_execute_if_else_boolean_false_conditions():
|
||||
"""Test IfElseNode with boolean conditions that should evaluate to false"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -384,6 +382,13 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean False Test",
|
||||
"type": "if-else",
|
||||
@ -405,8 +410,7 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data,
|
||||
@ -427,14 +431,11 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
|
||||
def test_execute_if_else_boolean_cases_structure():
|
||||
"""Test IfElseNode with boolean conditions using the new cases structure"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -450,6 +451,13 @@ def test_execute_if_else_boolean_cases_structure():
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Cases Test",
|
||||
"type": "if-else",
|
||||
@ -475,8 +483,7 @@ def test_execute_if_else_boolean_cases_structure():
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
@ -2,9 +2,10 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.list_operator.entities import (
|
||||
ExtractConfig,
|
||||
FilterBy,
|
||||
@ -16,6 +17,7 @@ from core.workflow.nodes.list_operator.entities import (
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -38,11 +40,21 @@ def list_operator_node():
|
||||
"id": "test_node_id",
|
||||
"data": node_data.model_dump(),
|
||||
}
|
||||
# Create properly configured mock for graph_init_params
|
||||
graph_init_params = MagicMock()
|
||||
graph_init_params.tenant_id = "test_tenant"
|
||||
graph_init_params.app_id = "test_app"
|
||||
graph_init_params.workflow_id = "test_workflow"
|
||||
graph_init_params.graph_config = {}
|
||||
graph_init_params.user_id = "test_user"
|
||||
graph_init_params.user_from = UserFrom.ACCOUNT
|
||||
graph_init_params.invoke_from = InvokeFrom.SERVICE_API
|
||||
graph_init_params.call_depth = 0
|
||||
|
||||
node = ListOperatorNode(
|
||||
id="test_node_id",
|
||||
config=node_config,
|
||||
graph_init_params=MagicMock(),
|
||||
graph=MagicMock(),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
# Initialize node data
|
||||
|
||||
@ -1,65 +0,0 @@
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value",
|
||||
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
None,
|
||||
None,
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
@ -1,115 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models import UserFrom, WorkflowType
|
||||
|
||||
|
||||
def _create_tool_node():
|
||||
data = ToolNodeData(
|
||||
title="Test Tool",
|
||||
tool_parameters={},
|
||||
provider_id="test_tool",
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_name="test tool",
|
||||
tool_name="test tool",
|
||||
tool_label="test tool",
|
||||
tool_configurations={},
|
||||
plugin_unique_identifier=None,
|
||||
desc="Exception handling test tool",
|
||||
error_strategy=ErrorStrategy.FAIL_BRANCH,
|
||||
version="1",
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
node = ToolNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
class MockToolRuntime:
|
||||
def get_merged_runtime_parameters(self):
|
||||
pass
|
||||
|
||||
|
||||
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from []
|
||||
raise ToolInvokeError("oops")
|
||||
|
||||
|
||||
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that ToolNode can handle ToolInvokeError when transforming
|
||||
messages generated by ToolEngine.generic_invoke.
|
||||
"""
|
||||
tool_node = _create_tool_node()
|
||||
|
||||
# Need to patch ToolManager and ToolEngine so that we don't
|
||||
# have to set up a database.
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_engine.ToolEngine.generic_invoke",
|
||||
lambda *args, **kwargs: mock_message_stream(),
|
||||
)
|
||||
|
||||
streams = list(tool_node._run())
|
||||
assert len(streams) == 1
|
||||
stream = streams[0]
|
||||
assert isinstance(stream, RunCompletedEvent)
|
||||
result = stream.run_result
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "oops" in result.error
|
||||
assert "Failed to invoke tool" in result.error
|
||||
assert result.error_type == "ToolInvokeError"
|
||||
@ -6,15 +6,13 @@ from uuid import uuid4
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
@ -29,22 +27,17 @@ def test_overwrite_string_variable():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -79,6 +72,13 @@ def test_overwrite_string_variable():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
@ -95,8 +95,7 @@ def test_overwrite_string_variable():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
@ -132,22 +131,17 @@ def test_append_variable_to_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -180,6 +174,13 @@ def test_append_variable_to_array():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
@ -196,8 +197,7 @@ def test_append_variable_to_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
@ -234,22 +234,17 @@ def test_clear_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -272,6 +267,13 @@ def test_clear_array():
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
@ -288,8 +290,7 @@ def test_clear_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
@ -4,15 +4,13 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
@ -77,22 +75,17 @@ def test_remove_first_from_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -115,6 +108,13 @@ def test_remove_first_from_array():
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -134,8 +134,7 @@ def test_remove_first_from_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
@ -143,15 +142,11 @@ def test_remove_first_from_array():
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
# Print the variable before running
|
||||
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
|
||||
# Run the node
|
||||
result = list(node.run())
|
||||
|
||||
# Print the variable after running and the result
|
||||
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
print(f"Result: {result}")
|
||||
# Completed run
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@ -169,22 +164,17 @@ def test_remove_last_from_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -207,6 +197,13 @@ def test_remove_last_from_array():
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -226,8 +223,7 @@ def test_remove_last_from_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
@ -253,22 +249,17 @@ def test_remove_first_from_empty_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -291,6 +282,13 @@ def test_remove_first_from_empty_array():
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -310,8 +308,7 @@ def test_remove_first_from_empty_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
@ -337,22 +334,17 @@ def test_remove_last_from_empty_array():
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"data": {"type": "assigner", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@ -375,6 +367,13 @@ def test_remove_last_from_empty_array():
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -394,8 +393,7 @@ def test_remove_last_from_empty_array():
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ from core.variables.variables import (
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from factories.variable_factory import build_segment, segment_to_variable
|
||||
|
||||
@ -68,18 +68,6 @@ def test_get_file_attribute(pool, file):
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_use_long_selector(pool):
|
||||
# The add method now only accepts 2-element selectors (node_id, variable_name)
|
||||
# Store nested data as an ObjectSegment instead
|
||||
nested_data = {"part_2": "test_value"}
|
||||
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
|
||||
|
||||
# The get method supports longer selectors for nested access
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
|
||||
|
||||
class TestVariablePool:
|
||||
def test_constructor(self):
|
||||
# Test with minimal required SystemVariable
|
||||
@ -284,11 +272,6 @@ class TestVariablePoolSerialization:
|
||||
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
||||
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
||||
|
||||
# Add nested variables as ObjectSegment
|
||||
# The add method only accepts 2-element selectors
|
||||
nested_obj = {"deep": {"var": "deep_value"}}
|
||||
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
|
||||
|
||||
def test_system_variables(self):
|
||||
sys_vars = SystemVariable(
|
||||
user_id="test_user_id",
|
||||
@ -379,7 +362,7 @@ class TestVariablePoolSerialization:
|
||||
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
|
||||
# TODO: assert the data for file object...
|
||||
|
||||
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
|
||||
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool):
|
||||
"""Assert that two VariablePools contain equivalent data"""
|
||||
|
||||
# Compare system variables
|
||||
@ -406,7 +389,6 @@ class TestVariablePoolSerialization:
|
||||
(self._NODE1_ID, "float_var"),
|
||||
(self._NODE2_ID, "array_string"),
|
||||
(self._NODE2_ID, "array_number"),
|
||||
(self._NODE3_ID, "nested", "deep", "var"),
|
||||
]
|
||||
|
||||
for selector in test_selectors:
|
||||
@ -442,3 +424,13 @@ class TestVariablePoolSerialization:
|
||||
loaded = VariablePool.model_validate(pool_dict)
|
||||
assert isinstance(loaded.variable_dictionary, defaultdict)
|
||||
loaded.add(["non_exist_node", "a"], 1)
|
||||
|
||||
|
||||
def test_get_attr():
|
||||
vp = VariablePool()
|
||||
value = {"output": StringSegment(value="hello")}
|
||||
|
||||
vp.add(["node", "name"], value)
|
||||
res = vp.get(["node", "name", "output"])
|
||||
assert res is not None
|
||||
assert res.value == "hello"
|
||||
|
||||
@ -11,11 +11,15 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
from core.workflow.entities import (
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@ -93,7 +97,7 @@ def mock_workflow_execution_repository():
|
||||
def real_workflow_entity():
|
||||
return CycleManagerWorkflowInfo(
|
||||
workflow_id="test-workflow-id", # Matches ID used in other fixtures
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version="1.0.0",
|
||||
graph_data={
|
||||
"nodes": [
|
||||
@ -207,8 +211,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
@ -241,8 +245,8 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
@ -278,8 +282,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
@ -293,12 +297,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
event.node_execution_id = "test-node-execution-id"
|
||||
event.node_id = "test-node-id"
|
||||
event.node_type = NodeType.LLM
|
||||
|
||||
# Create node_data as a separate mock
|
||||
node_data = MagicMock()
|
||||
node_data.title = "Test Node"
|
||||
event.node_data = node_data
|
||||
|
||||
event.node_title = "Test Node"
|
||||
event.predecessor_node_id = "test-predecessor-node-id"
|
||||
event.node_run_index = 1
|
||||
event.parallel_mode_run_id = "test-parallel-mode-run-id"
|
||||
@ -317,7 +316,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
assert result.node_execution_id == event.node_execution_id
|
||||
assert result.node_id == event.node_id
|
||||
assert result.node_type == event.node_type
|
||||
assert result.title == event.node_data.title
|
||||
assert result.title == event.node_title
|
||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
# Verify save was called
|
||||
@ -331,8 +330,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
@ -405,8 +404,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
|
||||
456
api/tests/unit_tests/core/workflow/test_workflow_entry.py
Normal file
456
api/tests/unit_tests/core/workflow/test_workflow_entry.py
Normal file
@ -0,0 +1,456 @@
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileType
|
||||
from core.file.models import File, FileTransferMethod
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
class TestWorkflowEntry:
|
||||
"""Test WorkflowEntry class methods."""
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
|
||||
"""Test mapping system variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with system variables
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="test_user_id",
|
||||
app_id="test_app_id",
|
||||
workflow_id="test_workflow_id",
|
||||
),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping - sys variables mapped to other nodes
|
||||
variable_mapping = {
|
||||
"node1.input1": ["node1", "input1"], # Regular mapping
|
||||
"node2.query": ["node2", "query"], # Regular mapping
|
||||
"sys.user_id": ["output_node", "user"], # System variable mapping
|
||||
}
|
||||
|
||||
# User inputs including sys variables
|
||||
user_inputs = {
|
||||
"node1.input1": "new_user_id",
|
||||
"node2.query": "test query",
|
||||
"sys.user_id": "system_user",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added to pool
|
||||
# Note: variable_pool.get returns Variable objects, not raw values
|
||||
node1_var = variable_pool.get(["node1", "input1"])
|
||||
assert node1_var is not None
|
||||
assert node1_var.value == "new_user_id"
|
||||
|
||||
node2_var = variable_pool.get(["node2", "query"])
|
||||
assert node2_var is not None
|
||||
assert node2_var.value == "test query"
|
||||
|
||||
# System variable gets mapped to output node
|
||||
output_var = variable_pool.get(["output_node", "user"])
|
||||
assert output_var is not None
|
||||
assert output_var.value == "system_user"
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
environment_variables=[env_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add env variable to pool (simulating initialization)
|
||||
variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
|
||||
|
||||
# Define variable mapping - env variables should not be overridden
|
||||
variable_mapping = {
|
||||
"node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"],
|
||||
"node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.api_key": "user_provided_key", # This should not override existing env var
|
||||
"node2.new_env": "new_env_value",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify env variable was not overridden
|
||||
env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
|
||||
assert env_value is not None
|
||||
assert env_value.value == "existing_key" # Should remain unchanged
|
||||
|
||||
# New env variables from user input should not be added
|
||||
assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self):
|
||||
"""Test mapping conversation variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with conversation variables
|
||||
conv_var = StringVariable(name="last_message", value="Hello")
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add conversation variable to pool
|
||||
variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.message": ["node1", "message"], # Map to regular node
|
||||
"conversation.context": ["chat_node", "context"], # Conversation var to regular node
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.message": "Updated message",
|
||||
"conversation.context": "New context",
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added to their target nodes
|
||||
node1_var = variable_pool.get(["node1", "message"])
|
||||
assert node1_var is not None
|
||||
assert node1_var.value == "Updated message"
|
||||
|
||||
chat_var = variable_pool.get(["chat_node", "context"])
|
||||
assert chat_var is not None
|
||||
assert chat_var.value == "New context"
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
|
||||
"""Test mapping regular node variables from user inputs to variable pool."""
|
||||
# Initialize empty variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping for regular nodes
|
||||
variable_mapping = {
|
||||
"input_node.text": ["input_node", "text"],
|
||||
"llm_node.prompt": ["llm_node", "prompt"],
|
||||
"code_node.input": ["code_node", "input"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"input_node.text": "User input text",
|
||||
"llm_node.prompt": "Generate a summary",
|
||||
"code_node.input": {"key": "value"},
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify regular variables were added
|
||||
text_var = variable_pool.get(["input_node", "text"])
|
||||
assert text_var is not None
|
||||
assert text_var.value == "User input text"
|
||||
|
||||
prompt_var = variable_pool.get(["llm_node", "prompt"])
|
||||
assert prompt_var is not None
|
||||
assert prompt_var.value == "Generate a summary"
|
||||
|
||||
input_var = variable_pool.get(["code_node", "input"])
|
||||
assert input_var is not None
|
||||
assert input_var.value == {"key": "value"}
|
||||
|
||||
def test_mapping_user_inputs_with_file_handling(self):
|
||||
"""Test mapping file inputs from user inputs to variable pool."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"file_node.file": ["file_node", "file"],
|
||||
"file_node.files": ["file_node", "files"],
|
||||
}
|
||||
|
||||
# User inputs with file data - using remote_url which doesn't require upload_file_id
|
||||
user_inputs = {
|
||||
"file_node.file": {
|
||||
"type": "document",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/test.pdf",
|
||||
},
|
||||
"file_node.files": [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/image1.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote_url",
|
||||
"url": "http://example.com/image2.jpg",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify file was converted and added
|
||||
file_var = variable_pool.get(["file_node", "file"])
|
||||
assert file_var is not None
|
||||
assert file_var.value.type == FileType.DOCUMENT
|
||||
assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL
|
||||
|
||||
# Verify file list was converted and added
|
||||
files_var = variable_pool.get(["file_node", "files"])
|
||||
assert files_var is not None
|
||||
assert isinstance(files_var.value, list)
|
||||
assert len(files_var.value) == 2
|
||||
assert all(isinstance(f, File) for f in files_var.value)
|
||||
assert files_var.value[0].type == FileType.IMAGE
|
||||
assert files_var.value[1].type == FileType.IMAGE
|
||||
assert files_var.value[0].type == FileType.IMAGE
|
||||
assert files_var.value[1].type == FileType.IMAGE
|
||||
|
||||
def test_mapping_user_inputs_missing_variable_error(self):
|
||||
"""Test that mapping raises error when required variable is missing."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.required_input": ["node1", "required_input"],
|
||||
}
|
||||
|
||||
# User inputs without required variable
|
||||
user_inputs = {
|
||||
"node1.other_input": "some value",
|
||||
}
|
||||
|
||||
# Should raise ValueError for missing variable
|
||||
with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"):
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
def test_mapping_user_inputs_with_alternative_key_format(self):
|
||||
"""Test mapping with alternative key format (without node prefix)."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping
|
||||
variable_mapping = {
|
||||
"node1.input": ["node1", "input"],
|
||||
}
|
||||
|
||||
# User inputs with alternative key format
|
||||
user_inputs = {
|
||||
"input": "value without node prefix", # Alternative format without node prefix
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variable was added using alternative key
|
||||
input_var = variable_pool.get(["node1", "input"])
|
||||
assert input_var is not None
|
||||
assert input_var.value == "value without node prefix"
|
||||
|
||||
def test_mapping_user_inputs_with_complex_selectors(self):
|
||||
"""Test mapping with complex node variable keys."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping - selectors can only have 2 elements
|
||||
variable_mapping = {
|
||||
"node1.data.field1": ["node1", "data_field1"], # Complex key mapped to simple selector
|
||||
"node2.config.settings.timeout": ["node2", "timeout"], # Complex key mapped to simple selector
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"node1.data.field1": "nested value",
|
||||
"node2.config.settings.timeout": 30,
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify variables were added with simple selectors
|
||||
data_var = variable_pool.get(["node1", "data_field1"])
|
||||
assert data_var is not None
|
||||
assert data_var.value == "nested value"
|
||||
|
||||
timeout_var = variable_pool.get(["node2", "timeout"])
|
||||
assert timeout_var is not None
|
||||
assert timeout_var.value == 30
|
||||
|
||||
def test_mapping_user_inputs_invalid_node_variable(self):
|
||||
"""Test that mapping handles invalid node variable format."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Define variable mapping with single element node variable (at least one dot is required)
|
||||
variable_mapping = {
|
||||
"singleelement": ["node1", "input"], # No dot separator
|
||||
}
|
||||
|
||||
user_inputs = {"singleelement": "some value"} # Must use exact key
|
||||
|
||||
# Should NOT raise error - function accepts it and uses direct key
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify it was added
|
||||
var = variable_pool.get(["node1", "input"])
|
||||
assert var is not None
|
||||
assert var.value == "some value"
|
||||
|
||||
def test_mapping_all_variable_types_together(self):
|
||||
"""Test mapping all four types of variables in one operation."""
|
||||
# Initialize variable pool with some existing variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
conv_var = StringVariable(name="session_id", value="session123")
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="test_user",
|
||||
app_id="test_app",
|
||||
query="initial query",
|
||||
),
|
||||
environment_variables=[env_var],
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
# Add existing variables to pool
|
||||
variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
|
||||
variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var)
|
||||
|
||||
# Define comprehensive variable mapping
|
||||
variable_mapping = {
|
||||
# System variables mapped to regular nodes
|
||||
"sys.user_id": ["start", "user"],
|
||||
"sys.app_id": ["start", "app"],
|
||||
# Environment variables (won't be overridden)
|
||||
"env.API_KEY": ["config", "api_key"],
|
||||
# Conversation variables mapped to regular nodes
|
||||
"conversation.session_id": ["chat", "session"],
|
||||
# Regular variables
|
||||
"input.text": ["input", "text"],
|
||||
"process.data": ["process", "data"],
|
||||
}
|
||||
|
||||
# User inputs
|
||||
user_inputs = {
|
||||
"sys.user_id": "new_user",
|
||||
"sys.app_id": "new_app",
|
||||
"env.API_KEY": "attempted_override", # Should not override env var
|
||||
"conversation.session_id": "new_session",
|
||||
"input.text": "user input text",
|
||||
"process.data": {"value": 123, "status": "active"},
|
||||
}
|
||||
|
||||
# Execute mapping
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
# Verify system variables were added to their target nodes
|
||||
start_user = variable_pool.get(["start", "user"])
|
||||
assert start_user is not None
|
||||
assert start_user.value == "new_user"
|
||||
|
||||
start_app = variable_pool.get(["start", "app"])
|
||||
assert start_app is not None
|
||||
assert start_app.value == "new_app"
|
||||
|
||||
# Verify env variable was not overridden (still has original value)
|
||||
env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"])
|
||||
assert env_value is not None
|
||||
assert env_value.value == "existing_key"
|
||||
|
||||
# Environment variables get mapped to other nodes even when they exist in env pool
|
||||
# But the original env value remains unchanged
|
||||
config_api_key = variable_pool.get(["config", "api_key"])
|
||||
assert config_api_key is not None
|
||||
assert config_api_key.value == "attempted_override"
|
||||
|
||||
# Verify conversation variable was mapped to target node
|
||||
chat_session = variable_pool.get(["chat", "session"])
|
||||
assert chat_session is not None
|
||||
assert chat_session.value == "new_session"
|
||||
|
||||
# Verify regular variables were added
|
||||
input_text = variable_pool.get(["input", "text"])
|
||||
assert input_text is not None
|
||||
assert input_text.value == "user input text"
|
||||
|
||||
process_data = variable_pool.get(["process", "data"])
|
||||
assert process_data is not None
|
||||
assert process_data.value == {"value": 123, "status": "active"}
|
||||
@ -0,0 +1,144 @@
|
||||
"""Tests for WorkflowEntry integration with Redis command channel."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class TestWorkflowEntryRedisChannel:
|
||||
"""Test suite for WorkflowEntry with Redis command channel."""
|
||||
|
||||
def test_workflow_entry_uses_provided_redis_channel(self):
|
||||
"""Test that WorkflowEntry uses the provided Redis command channel."""
|
||||
# Mock dependencies
|
||||
mock_graph = MagicMock()
|
||||
mock_graph_config = {"nodes": [], "edges": []}
|
||||
mock_variable_pool = MagicMock(spec=VariablePool)
|
||||
mock_graph_runtime_state = MagicMock(spec=GraphRuntimeState)
|
||||
mock_graph_runtime_state.variable_pool = mock_variable_pool
|
||||
|
||||
# Create a mock Redis channel
|
||||
mock_redis_client = MagicMock()
|
||||
redis_channel = RedisChannel(mock_redis_client, "test:channel:key")
|
||||
|
||||
# Patch GraphEngine to verify it receives the Redis channel
|
||||
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
|
||||
mock_graph_engine = MagicMock()
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
|
||||
# Create WorkflowEntry with Redis channel
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
workflow_id="test-workflow",
|
||||
graph_config=mock_graph_config,
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=mock_variable_pool,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
command_channel=redis_channel, # Provide Redis channel
|
||||
)
|
||||
|
||||
# Verify GraphEngine was initialized with the Redis channel
|
||||
MockGraphEngine.assert_called_once()
|
||||
call_args = MockGraphEngine.call_args[1]
|
||||
assert call_args["command_channel"] == redis_channel
|
||||
assert workflow_entry.command_channel == redis_channel
|
||||
|
||||
def test_workflow_entry_defaults_to_inmemory_channel(self):
|
||||
"""Test that WorkflowEntry defaults to InMemoryChannel when no channel is provided."""
|
||||
# Mock dependencies
|
||||
mock_graph = MagicMock()
|
||||
mock_graph_config = {"nodes": [], "edges": []}
|
||||
mock_variable_pool = MagicMock(spec=VariablePool)
|
||||
mock_graph_runtime_state = MagicMock(spec=GraphRuntimeState)
|
||||
mock_graph_runtime_state.variable_pool = mock_variable_pool
|
||||
|
||||
# Patch GraphEngine and InMemoryChannel
|
||||
with (
|
||||
patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine,
|
||||
patch("core.workflow.workflow_entry.InMemoryChannel") as MockInMemoryChannel,
|
||||
):
|
||||
mock_graph_engine = MagicMock()
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
mock_inmemory_channel = MagicMock()
|
||||
MockInMemoryChannel.return_value = mock_inmemory_channel
|
||||
|
||||
# Create WorkflowEntry without providing a channel
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
workflow_id="test-workflow",
|
||||
graph_config=mock_graph_config,
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=mock_variable_pool,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
command_channel=None, # No channel provided
|
||||
)
|
||||
|
||||
# Verify InMemoryChannel was created
|
||||
MockInMemoryChannel.assert_called_once()
|
||||
|
||||
# Verify GraphEngine was initialized with the InMemory channel
|
||||
MockGraphEngine.assert_called_once()
|
||||
call_args = MockGraphEngine.call_args[1]
|
||||
assert call_args["command_channel"] == mock_inmemory_channel
|
||||
assert workflow_entry.command_channel == mock_inmemory_channel
|
||||
|
||||
def test_workflow_entry_run_with_redis_channel(self):
|
||||
"""Test that WorkflowEntry.run() works correctly with Redis channel."""
|
||||
# Mock dependencies
|
||||
mock_graph = MagicMock()
|
||||
mock_graph_config = {"nodes": [], "edges": []}
|
||||
mock_variable_pool = MagicMock(spec=VariablePool)
|
||||
mock_graph_runtime_state = MagicMock(spec=GraphRuntimeState)
|
||||
mock_graph_runtime_state.variable_pool = mock_variable_pool
|
||||
|
||||
# Create a mock Redis channel
|
||||
mock_redis_client = MagicMock()
|
||||
redis_channel = RedisChannel(mock_redis_client, "test:channel:key")
|
||||
|
||||
# Mock events to be generated
|
||||
mock_event1 = MagicMock()
|
||||
mock_event2 = MagicMock()
|
||||
|
||||
# Patch GraphEngine
|
||||
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
|
||||
mock_graph_engine = MagicMock()
|
||||
mock_graph_engine.run.return_value = iter([mock_event1, mock_event2])
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
|
||||
# Create WorkflowEntry with Redis channel
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
workflow_id="test-workflow",
|
||||
graph_config=mock_graph_config,
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=mock_variable_pool,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
command_channel=redis_channel,
|
||||
)
|
||||
|
||||
# Run the workflow
|
||||
events = list(workflow_entry.run())
|
||||
|
||||
# Verify events were generated
|
||||
assert len(events) == 2
|
||||
assert events[0] == mock_event1
|
||||
assert events[1] == mock_event2
|
||||
@ -1,7 +1,7 @@
|
||||
import dataclasses
|
||||
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
def test_extract_selectors_from_template():
|
||||
|
||||
@ -11,12 +11,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_success_with_all_config(self):
|
||||
"""Test successful initialization when all required config is provided."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -31,7 +31,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_url_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_URL is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = None
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -41,7 +41,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_api_key_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = None
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -51,7 +51,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_bucket_name_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = None
|
||||
@ -61,12 +61,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_not_exists(self):
|
||||
"""Test create_bucket creates bucket when it doesn't exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -77,12 +77,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_exists(self):
|
||||
"""Test create_bucket does not create bucket when it already exists."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -94,12 +94,12 @@ class TestSupabaseStorage:
|
||||
@pytest.fixture
|
||||
def storage_with_mock_client(self):
|
||||
"""Fixture providing SupabaseStorage with mocked client."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -251,12 +251,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_true_when_bucket_found(self):
|
||||
"""Test bucket_exists returns True when bucket is found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -271,12 +271,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_bucket_not_found(self):
|
||||
"""Test bucket_exists returns False when bucket is not found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
@ -294,12 +294,12 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_no_buckets(self):
|
||||
"""Test bucket_exists returns False when no buckets exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client") as mock_client_class:
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from hypothesis import given
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
@ -371,7 +371,7 @@ def test_build_segment_array_any_properties():
|
||||
# Test properties
|
||||
assert segment.text == str(mixed_values)
|
||||
assert segment.log == str(mixed_values)
|
||||
assert segment.markdown == "string\n42\nNone"
|
||||
assert segment.markdown == "- string\n- 42\n- None"
|
||||
assert segment.to_object() == mixed_values
|
||||
|
||||
|
||||
@ -486,13 +486,14 @@ def _generate_file(draw) -> File:
|
||||
def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
|
||||
return st.one_of(
|
||||
st.none(),
|
||||
st.integers(),
|
||||
st.floats(),
|
||||
st.text(),
|
||||
st.integers(min_value=-(10**6), max_value=10**6),
|
||||
st.floats(allow_nan=True, allow_infinity=False),
|
||||
st.text(max_size=50),
|
||||
_generate_file(),
|
||||
)
|
||||
|
||||
|
||||
@settings(max_examples=50)
|
||||
@given(_scalar_value())
|
||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
seg = variable_factory.build_segment(value)
|
||||
@ -503,7 +504,8 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
assert seg.value == value
|
||||
|
||||
|
||||
@given(st.lists(_scalar_value()))
|
||||
@settings(max_examples=50)
|
||||
@given(values=st.lists(_scalar_value(), max_size=20))
|
||||
def test_build_segment_and_extract_values_for_array_types(values):
|
||||
seg = variable_factory.build_segment(values)
|
||||
assert seg.value == values
|
||||
|
||||
@ -27,7 +27,7 @@ from services.feature_service import BrandingModel
|
||||
class MockEmailRenderer:
|
||||
"""Mock implementation of EmailRenderer protocol"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
def render_template(self, template_path: str, **context: Any) -> str:
|
||||
@ -39,7 +39,7 @@ class MockEmailRenderer:
|
||||
class MockBrandingService:
|
||||
"""Mock implementation of BrandingService protocol"""
|
||||
|
||||
def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None:
|
||||
def __init__(self, enabled: bool = False, application_title: str = "Dify"):
|
||||
self.enabled = enabled
|
||||
self.application_title = application_title
|
||||
|
||||
@ -54,10 +54,10 @@ class MockBrandingService:
|
||||
class MockEmailSender:
|
||||
"""Mock implementation of EmailSender protocol"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.sent_emails: list[dict[str, str]] = []
|
||||
|
||||
def send_email(self, to: str, subject: str, html_content: str) -> None:
|
||||
def send_email(self, to: str, subject: str, html_content: str):
|
||||
"""Mock send_email that records sent emails"""
|
||||
self.sent_emails.append(
|
||||
{
|
||||
@ -134,7 +134,7 @@ class TestEmailI18nService:
|
||||
email_service: EmailI18nService,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with English language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -162,7 +162,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with Chinese language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -181,7 +181,7 @@ class TestEmailI18nService:
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with branding enabled"""
|
||||
# Create branding service with branding enabled
|
||||
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
|
||||
@ -215,7 +215,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test language fallback to English when requested language not available"""
|
||||
# Request invite member in Chinese (not configured)
|
||||
email_service.send_email(
|
||||
@ -233,7 +233,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test unknown language code falls back to English"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -246,13 +246,50 @@ class TestEmailI18nService:
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Reset Your Dify Password"
|
||||
|
||||
def test_subject_format_keyerror_fallback_path(
|
||||
self,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Trigger subject KeyError and cover except branch."""
|
||||
# Config with subject that references an unknown key (no {application_title} to avoid second format)
|
||||
config = EmailI18nConfig(
|
||||
templates={
|
||||
EmailType.INVITE_MEMBER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Invite: {unknown_placeholder}",
|
||||
template_path="invite_member_en.html",
|
||||
branded_template_path="branded/invite_member_en.html",
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
branding_service = MockBrandingService(enabled=False)
|
||||
service = EmailI18nService(
|
||||
config=config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
|
||||
service.send_email(
|
||||
email_type=EmailType.INVITE_MEMBER,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
# Subject is left unformatted due to KeyError fallback path without application_title
|
||||
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
|
||||
|
||||
def test_send_change_email_old_phase(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email for old email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
|
||||
@ -290,7 +327,7 @@ class TestEmailI18nService:
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email for new email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
|
||||
@ -325,7 +362,7 @@ class TestEmailI18nService:
|
||||
def test_send_change_email_invalid_phase(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email with invalid phase raises error"""
|
||||
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
|
||||
email_service.send_change_email(
|
||||
@ -339,7 +376,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending raw email to single recipient"""
|
||||
email_service.send_raw_email(
|
||||
to="test@example.com",
|
||||
@ -357,7 +394,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending raw email to multiple recipients"""
|
||||
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
|
||||
|
||||
@ -378,7 +415,7 @@ class TestEmailI18nService:
|
||||
def test_get_template_missing_email_type(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
) -> None:
|
||||
):
|
||||
"""Test getting template for missing email type raises error"""
|
||||
with pytest.raises(ValueError, match="No templates configured for email type"):
|
||||
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
|
||||
@ -386,7 +423,7 @@ class TestEmailI18nService:
|
||||
def test_get_template_missing_language_and_english(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
) -> None:
|
||||
):
|
||||
"""Test error when neither requested language nor English fallback exists"""
|
||||
# Add template without English fallback
|
||||
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
|
||||
@ -407,7 +444,7 @@ class TestEmailI18nService:
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test subject templating with custom variables"""
|
||||
# Add template with variable in subject
|
||||
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
|
||||
@ -437,7 +474,7 @@ class TestEmailI18nService:
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "You are now the owner of My Workspace"
|
||||
|
||||
def test_email_language_from_language_code(self) -> None:
|
||||
def test_email_language_from_language_code(self):
|
||||
"""Test EmailLanguage.from_language_code method"""
|
||||
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
|
||||
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
|
||||
@ -448,7 +485,7 @@ class TestEmailI18nService:
|
||||
class TestEmailI18nIntegration:
|
||||
"""Integration tests for email i18n components"""
|
||||
|
||||
def test_create_default_email_config(self) -> None:
|
||||
def test_create_default_email_config(self):
|
||||
"""Test creating default email configuration"""
|
||||
config = create_default_email_config()
|
||||
|
||||
@ -476,7 +513,7 @@ class TestEmailI18nIntegration:
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
|
||||
|
||||
def test_get_email_i18n_service(self) -> None:
|
||||
def test_get_email_i18n_service(self):
|
||||
"""Test getting global email i18n service instance"""
|
||||
service1 = get_email_i18n_service()
|
||||
service2 = get_email_i18n_service()
|
||||
@ -484,7 +521,7 @@ class TestEmailI18nIntegration:
|
||||
# Should return the same instance
|
||||
assert service1 is service2
|
||||
|
||||
def test_flask_email_renderer(self) -> None:
|
||||
def test_flask_email_renderer(self):
|
||||
"""Test FlaskEmailRenderer implementation"""
|
||||
renderer = FlaskEmailRenderer()
|
||||
|
||||
@ -494,7 +531,7 @@ class TestEmailI18nIntegration:
|
||||
with pytest.raises(TemplateNotFound):
|
||||
renderer.render_template("test.html", foo="bar")
|
||||
|
||||
def test_flask_mail_sender_not_initialized(self) -> None:
|
||||
def test_flask_mail_sender_not_initialized(self):
|
||||
"""Test FlaskMailSender when mail is not initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
@ -514,7 +551,7 @@ class TestEmailI18nIntegration:
|
||||
# Restore original mail
|
||||
libs.email_i18n.mail = original_mail
|
||||
|
||||
def test_flask_mail_sender_initialized(self) -> None:
|
||||
def test_flask_mail_sender_initialized(self):
|
||||
"""Test FlaskMailSender when mail is initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
|
||||
122
api/tests/unit_tests/libs/test_external_api.py
Normal file
122
api/tests/unit_tests/libs/test_external_api.py
Normal file
@ -0,0 +1,122 @@
|
||||
from flask import Blueprint, Flask
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
||||
def _create_api_app():
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("t", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/bad-request")
|
||||
class Bad(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise BadRequest("invalid input")
|
||||
|
||||
@api.route("/unauth")
|
||||
class Unauth(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise Unauthorized("auth required")
|
||||
|
||||
@api.route("/value-error")
|
||||
class ValErr(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise ValueError("boom")
|
||||
|
||||
@api.route("/quota")
|
||||
class Quota(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise AppInvokeQuotaExceededError("quota exceeded")
|
||||
|
||||
@api.route("/general")
|
||||
class Gen(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise RuntimeError("oops")
|
||||
|
||||
# Note: We avoid altering default_mediatype to keep normal error paths
|
||||
|
||||
# Special 400 message rewrite
|
||||
@api.route("/json-empty")
|
||||
class JsonEmpty(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Force the specific message the handler rewrites
|
||||
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
||||
raise e
|
||||
|
||||
# 400 mapping payload path
|
||||
@api.route("/param-errors")
|
||||
class ParamErrors(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Coerce a mapping description to trigger param error shaping
|
||||
e.description = {"field": "is required"} # type: ignore[assignment]
|
||||
raise e
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
return app
|
||||
|
||||
|
||||
def test_external_api_error_handlers_basic_paths():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# 400
|
||||
res = client.get("/api/bad-request")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "bad_request"
|
||||
assert data["status"] == 400
|
||||
|
||||
# 401
|
||||
res = client.get("/api/unauth")
|
||||
assert res.status_code == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# 400 ValueError
|
||||
res = client.get("/api/value-error")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["code"] == "invalid_param"
|
||||
|
||||
# 500 general
|
||||
res = client.get("/api/general")
|
||||
assert res.status_code == 500
|
||||
assert res.get_json()["status"] == 500
|
||||
|
||||
|
||||
def test_external_api_json_message_and_bad_request_rewrite():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# JSON empty special rewrite
|
||||
res = client.get("/api/json-empty")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
55
api/tests/unit_tests/libs/test_file_utils.py
Normal file
55
api/tests/unit_tests/libs/test_file_utils.py
Normal file
@ -0,0 +1,55 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_parent(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "target.txt", max_search_parent_depth=5)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_current(tmp_path: Path):
|
||||
base = tmp_path / "x"
|
||||
base.mkdir()
|
||||
target = base / "here.txt"
|
||||
target.write_text("x", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "here.txt", max_search_parent_depth=1)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_not_found_raises(tmp_path: Path):
|
||||
base = tmp_path / "m" / "n"
|
||||
base.mkdir(parents=True)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
search_file_upwards(base, "missing.txt", max_search_parent_depth=3)
|
||||
# error message should contain file name and base path
|
||||
msg = str(exc.value)
|
||||
assert "missing.txt" in msg
|
||||
assert str(base) in msg
|
||||
|
||||
|
||||
def test_search_file_upwards_root_breaks_and_raises():
|
||||
# Using filesystem root triggers the 'break' branch (parent == current)
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1)
|
||||
|
||||
|
||||
def test_search_file_upwards_depth_limit_raises(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
# The file is 2 levels up from `c` (in `a`), but search depth is only 2.
|
||||
# The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3).
|
||||
# So, this should not find the file and should raise an error.
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(base, "target.txt", max_search_parent_depth=2)
|
||||
@ -1,6 +1,5 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask:
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str) -> Optional[User]:
|
||||
def load_user(user_id: str) -> User | None:
|
||||
if user_id == "test_user":
|
||||
return User("test_user")
|
||||
return None
|
||||
|
||||
88
api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
88
api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from libs.json_in_md_parser import (
|
||||
parse_and_check_json_markdown,
|
||||
parse_json_markdown,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_json():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1, "b": "x"}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"a": 1, "b": "x"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_generic():
|
||||
src = """
|
||||
```
|
||||
{"k": [1, 2, 3]}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"k": [1, 2, 3]}
|
||||
|
||||
|
||||
def test_parse_json_markdown_single_backticks():
|
||||
src = '`{"x": true}`'
|
||||
assert parse_json_markdown(src) == {"x": True}
|
||||
|
||||
|
||||
def test_parse_json_markdown_braces_only():
|
||||
src = ' {\n \t"ok": "yes"\n} '
|
||||
assert parse_json_markdown(src) == {"ok": "yes"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_not_found():
|
||||
with pytest.raises(ValueError):
|
||||
parse_json_markdown("no json here")
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_missing_key():
|
||||
src = """
|
||||
```
|
||||
{"present": 1}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, ["present", "missing"])
|
||||
assert "expected key `missing`" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_invalid_json():
|
||||
src = """
|
||||
```json
|
||||
{invalid json}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, [])
|
||||
assert "got invalid json object" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_success():
|
||||
src = """
|
||||
```json
|
||||
{"present": 1, "other": 2}
|
||||
```
|
||||
"""
|
||||
obj = parse_and_check_json_markdown(src, ["present"])
|
||||
assert obj == {"present": 1, "other": 2}
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_multiple_blocks_fails():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1}
|
||||
```
|
||||
Some text
|
||||
```json
|
||||
{"b": 2}
|
||||
```
|
||||
"""
|
||||
# The current implementation is greedy and will match from the first
|
||||
# opening fence to the last closing fence, causing JSON decode failure.
|
||||
with pytest.raises(OutputParserError):
|
||||
parse_and_check_json_markdown(src, [])
|
||||
19
api/tests/unit_tests/libs/test_oauth_base.py
Normal file
19
api/tests/unit_tests/libs/test_oauth_base.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_authorization_url()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_access_token("code")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_raw_user_info("token")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({}) # type: ignore[name-defined]
|
||||
@ -1,8 +1,8 @@
|
||||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
requests.exceptions.HTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
httpx.HTTPError,
|
||||
httpx.ConnectError,
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
||||
25
api/tests/unit_tests/libs/test_orjson.py
Normal file
25
api/tests/unit_tests/libs/test_orjson.py
Normal file
@ -0,0 +1,25 @@
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from libs.orjson import orjson_dumps
|
||||
|
||||
|
||||
def test_orjson_dumps_round_trip_basic():
|
||||
obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}}
|
||||
s = orjson_dumps(obj)
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_with_unicode_and_indent():
|
||||
obj = {"msg": "你好,Dify"}
|
||||
s = orjson_dumps(obj, option=orjson.OPT_INDENT_2)
|
||||
# contains indentation newline/spaces
|
||||
assert "\n" in s
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_non_utf8_encoding_fails():
|
||||
obj = {"msg": "你好"}
|
||||
# orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails.
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
orjson_dumps(obj, encoding="ascii")
|
||||
@ -4,7 +4,7 @@ from Crypto.PublicKey import RSA
|
||||
from libs import gmpy2_pkcs10aep_cipher
|
||||
|
||||
|
||||
def test_gmpy2_pkcs10aep_cipher() -> None:
|
||||
def test_gmpy2_pkcs10aep_cipher():
|
||||
rsa_key_pair = pyrsa.newkeys(2048)
|
||||
public_key = rsa_key_pair[0].save_pkcs1()
|
||||
private_key = rsa_key_pair[1].save_pkcs1()
|
||||
|
||||
53
api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
53
api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
@ -0,0 +1,53 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from python_http_client.exceptions import UnauthorizedError
|
||||
|
||||
from libs.sendgrid import SendGridClient
|
||||
|
||||
|
||||
def _mail(to: str = "user@example.com") -> dict:
|
||||
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_success(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
# nested attribute access: client.mail.send.post
|
||||
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
sg.send(_mail())
|
||||
|
||||
mock_client_cls.assert_called_once()
|
||||
mock_client.client.mail.send.post.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(ValueError):
|
||||
sg.send(_mail(to=""))
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(UnauthorizedError):
|
||||
sg.send(_mail())
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(TimeoutError):
|
||||
sg.send(_mail())
|
||||
100
api/tests/unit_tests/libs/test_smtp_client.py
Normal file
100
api/tests/unit_tests/libs/test_smtp_client.py
Normal file
@ -0,0 +1,100 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.smtp import SMTPClient
|
||||
|
||||
|
||||
def _mail() -> dict:
|
||||
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=587,
|
||||
username="user",
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=True,
|
||||
)
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
|
||||
assert mock_smtp.ehlo.call_count == 2
|
||||
mock_smtp.starttls.assert_called_once()
|
||||
mock_smtp.login.assert_called_once_with("user", "pass")
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL")
|
||||
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
# Cover SMTP_SSL branch and TimeoutError handling
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
|
||||
mock_smtp_ssl_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=465,
|
||||
username="",
|
||||
password="",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=False,
|
||||
)
|
||||
with pytest.raises(TimeoutError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = RuntimeError("oops")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
with pytest.raises(RuntimeError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
|
||||
# Ensure we hit the specific SMTPException except branch
|
||||
import smtplib
|
||||
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=25,
|
||||
username="user", # non-empty to trigger login
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
)
|
||||
with pytest.raises(smtplib.SMTPException):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
@ -1,7 +1,7 @@
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
def test_account_is_privileged_role() -> None:
|
||||
def test_account_is_privileged_role():
|
||||
assert TenantAccountRole.ADMIN == "admin"
|
||||
assert TenantAccountRole.OWNER == "owner"
|
||||
assert TenantAccountRole.EDITOR == "editor"
|
||||
|
||||
83
api/tests/unit_tests/models/test_model.py
Normal file
83
api/tests/unit_tests/models/test_model.py
Normal file
@ -0,0 +1,83 @@
|
||||
import importlib
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import Message
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_file_helpers(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Patch file_helpers.get_signed_file_url to a deterministic stub.
|
||||
"""
|
||||
model_module = importlib.import_module("models.model")
|
||||
dummy = types.SimpleNamespace(get_signed_file_url=lambda fid: f"https://signed.example/{fid}")
|
||||
# Inject/override file_helpers on models.model
|
||||
monkeypatch.setattr(model_module, "file_helpers", dummy, raising=False)
|
||||
|
||||
|
||||
def _wrap_md(url: str) -> str:
|
||||
"""
|
||||
Wrap a raw URL into the markdown that re_sign_file_url_answer expects:
|
||||
[link](<url>)
|
||||
"""
|
||||
return f"please click [file]({url}) to download."
|
||||
|
||||
|
||||
def test_file_preview_valid_replaced():
|
||||
"""
|
||||
Valid file-preview URL must be re-signed:
|
||||
- Extract upload_file_id correctly
|
||||
- Replace the original URL with the signed URL
|
||||
"""
|
||||
upload_id = "abc-123"
|
||||
url = f"/files/{upload_id}/file-preview?timestamp=111&nonce=222&sign=333"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_file_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'file-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "zzz-001"
|
||||
# path deliberately misspelled: file-previe? (missing 'w')
|
||||
# and we append ¬e=file-preview to trick the old `"file-preview" in url` check.
|
||||
url = f"/files/{upload_id}/file-previe?timestamp=111&nonce=222&sign=333¬e=file-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled file-previe URL
|
||||
assert out == original
|
||||
|
||||
|
||||
def test_image_preview_valid_replaced():
|
||||
"""
|
||||
Valid image-preview URL must be re-signed.
|
||||
"""
|
||||
upload_id = "img-789"
|
||||
url = f"/files/{upload_id}/image-preview?timestamp=123&nonce=456&sign=789"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_image_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'image-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "img-err-42"
|
||||
url = f"/files/{upload_id}/image-previe?timestamp=1&nonce=2&sign=3¬e=image-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled image-previe URL
|
||||
assert out == original
|
||||
@ -154,7 +154,7 @@ class TestEnumText:
|
||||
TestCase(
|
||||
name="session insert with invalid type",
|
||||
action=lambda s: _session_insert_with_value(s, 1),
|
||||
exc_type=TypeError,
|
||||
exc_type=ValueError,
|
||||
),
|
||||
TestCase(
|
||||
name="insert with invalid value",
|
||||
@ -164,7 +164,7 @@ class TestEnumText:
|
||||
TestCase(
|
||||
name="insert with invalid type",
|
||||
action=lambda s: _insert_with_user(s, 1),
|
||||
exc_type=TypeError,
|
||||
exc_type=ValueError,
|
||||
),
|
||||
]
|
||||
for idx, c in enumerate(cases, 1):
|
||||
|
||||
@ -0,0 +1,212 @@
|
||||
"""
|
||||
Unit tests for WorkflowNodeExecutionOffload model, focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionModel:
|
||||
"""Test WorkflowNodeExecutionModel with process_data truncation features."""
|
||||
|
||||
def create_mock_offload_data(
|
||||
self,
|
||||
inputs_file_id: str | None = None,
|
||||
outputs_file_id: str | None = None,
|
||||
process_data_file_id: str | None = None,
|
||||
) -> WorkflowNodeExecutionOffload:
|
||||
"""Create a mock offload data object."""
|
||||
offload = Mock(spec=WorkflowNodeExecutionOffload)
|
||||
offload.inputs_file_id = inputs_file_id
|
||||
offload.outputs_file_id = outputs_file_id
|
||||
offload.process_data_file_id = process_data_file_id
|
||||
|
||||
# Mock file objects
|
||||
if inputs_file_id:
|
||||
offload.inputs_file = Mock(spec=UploadFile)
|
||||
else:
|
||||
offload.inputs_file = None
|
||||
|
||||
if outputs_file_id:
|
||||
offload.outputs_file = Mock(spec=UploadFile)
|
||||
else:
|
||||
offload.outputs_file = None
|
||||
|
||||
if process_data_file_id:
|
||||
offload.process_data_file = Mock(spec=UploadFile)
|
||||
else:
|
||||
offload.process_data_file = None
|
||||
|
||||
return offload
|
||||
|
||||
def test_process_data_truncated_property_false_when_no_offload_data(self):
|
||||
"""Test process_data_truncated returns False when no offload_data."""
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
execution.offload_data = []
|
||||
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_process_data_truncated_property_false_when_no_process_data_file(self):
|
||||
"""Test process_data_truncated returns False when no process_data file."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Create real offload instances for inputs and outputs but not process_data
|
||||
inputs_offload = WorkflowNodeExecutionOffload()
|
||||
inputs_offload.type_ = ExecutionOffLoadType.INPUTS
|
||||
inputs_offload.file_id = "inputs-file"
|
||||
|
||||
outputs_offload = WorkflowNodeExecutionOffload()
|
||||
outputs_offload.type_ = ExecutionOffLoadType.OUTPUTS
|
||||
outputs_offload.file_id = "outputs-file"
|
||||
|
||||
execution.offload_data = [inputs_offload, outputs_offload]
|
||||
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_process_data_truncated_property_true_when_process_data_file_exists(self):
|
||||
"""Test process_data_truncated returns True when process_data file exists."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Create a real offload instance for process_data
|
||||
process_data_offload = WorkflowNodeExecutionOffload()
|
||||
process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
|
||||
process_data_offload.file_id = "process-data-file-id"
|
||||
execution.offload_data = [process_data_offload]
|
||||
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
def test_load_full_process_data_with_no_offload_data(self):
|
||||
"""Test load_full_process_data when no offload data exists."""
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
execution.offload_data = []
|
||||
execution.process_data = '{"test": "data"}'
|
||||
|
||||
# Mock session and storage
|
||||
mock_session = Mock()
|
||||
mock_storage = Mock()
|
||||
|
||||
result = execution.load_full_process_data(mock_session, mock_storage)
|
||||
|
||||
assert result == {"test": "data"}
|
||||
|
||||
def test_load_full_process_data_with_no_file(self):
|
||||
"""Test load_full_process_data when no process_data file exists."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Create offload data for inputs only, not process_data
|
||||
inputs_offload = WorkflowNodeExecutionOffload()
|
||||
inputs_offload.type_ = ExecutionOffLoadType.INPUTS
|
||||
inputs_offload.file_id = "inputs-file"
|
||||
|
||||
execution.offload_data = [inputs_offload]
|
||||
execution.process_data = '{"test": "data"}'
|
||||
|
||||
# Mock session and storage
|
||||
mock_session = Mock()
|
||||
mock_storage = Mock()
|
||||
|
||||
result = execution.load_full_process_data(mock_session, mock_storage)
|
||||
|
||||
assert result == {"test": "data"}
|
||||
|
||||
def test_load_full_process_data_with_file(self):
|
||||
"""Test load_full_process_data when process_data file exists."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Create process_data offload
|
||||
process_data_offload = WorkflowNodeExecutionOffload()
|
||||
process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
|
||||
process_data_offload.file_id = "file-id"
|
||||
|
||||
execution.offload_data = [process_data_offload]
|
||||
execution.process_data = '{"truncated": "data"}'
|
||||
|
||||
# Mock session and storage
|
||||
mock_session = Mock()
|
||||
mock_storage = Mock()
|
||||
|
||||
# Mock the _load_full_content method to return full data
|
||||
full_process_data = {"full": "data", "large_field": "x" * 10000}
|
||||
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
# Mock the _load_full_content method
|
||||
def mock_load_full_content(session, file_id, storage):
|
||||
assert session == mock_session
|
||||
assert file_id == "file-id"
|
||||
assert storage == mock_storage
|
||||
return full_process_data
|
||||
|
||||
mp.setattr(execution, "_load_full_content", mock_load_full_content)
|
||||
|
||||
result = execution.load_full_process_data(mock_session, mock_storage)
|
||||
|
||||
assert result == full_process_data
|
||||
|
||||
def test_consistency_with_inputs_outputs_truncation(self):
|
||||
"""Test that process_data truncation behaves consistently with inputs/outputs."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Create offload data for all three types
|
||||
inputs_offload = WorkflowNodeExecutionOffload()
|
||||
inputs_offload.type_ = ExecutionOffLoadType.INPUTS
|
||||
inputs_offload.file_id = "inputs-file"
|
||||
|
||||
outputs_offload = WorkflowNodeExecutionOffload()
|
||||
outputs_offload.type_ = ExecutionOffLoadType.OUTPUTS
|
||||
outputs_offload.file_id = "outputs-file"
|
||||
|
||||
process_data_offload = WorkflowNodeExecutionOffload()
|
||||
process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
|
||||
process_data_offload.file_id = "process-data-file"
|
||||
|
||||
execution.offload_data = [inputs_offload, outputs_offload, process_data_offload]
|
||||
|
||||
# All three should be truncated
|
||||
assert execution.inputs_truncated is True
|
||||
assert execution.outputs_truncated is True
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
def test_mixed_truncation_states(self):
|
||||
"""Test mixed states of truncation."""
|
||||
from models.enums import ExecutionOffLoadType
|
||||
|
||||
execution = WorkflowNodeExecutionModel()
|
||||
|
||||
# Only process_data is truncated
|
||||
process_data_offload = WorkflowNodeExecutionOffload()
|
||||
process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
|
||||
process_data_offload.file_id = "process-data-file"
|
||||
|
||||
execution.offload_data = [process_data_offload]
|
||||
|
||||
assert execution.inputs_truncated is False
|
||||
assert execution.outputs_truncated is False
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
def test_preload_offload_data_and_files_method_exists(self):
|
||||
"""Test that the preload method includes process_data_file."""
|
||||
# This test verifies the method exists and can be called
|
||||
# The actual SQL behavior would be tested in integration tests
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(WorkflowNodeExecutionModel)
|
||||
|
||||
# This should not raise an exception
|
||||
preloaded_stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(stmt)
|
||||
|
||||
# The statement should be modified (different object)
|
||||
assert preloaded_stmt is not stmt
|
||||
@ -21,8 +21,11 @@ def get_example_filename() -> str:
|
||||
return "test.txt"
|
||||
|
||||
|
||||
def get_example_data() -> bytes:
|
||||
return b"test"
|
||||
def get_example_data(length: int = 4) -> bytes:
|
||||
chars = "test"
|
||||
result = "".join(chars[i % len(chars)] for i in range(length)).encode()
|
||||
assert len(result) == length
|
||||
return result
|
||||
|
||||
|
||||
def get_example_filepath() -> str:
|
||||
|
||||
@ -57,12 +57,19 @@ class TestOpenDAL:
|
||||
def test_load_stream(self):
|
||||
"""Test loading data as a stream."""
|
||||
filename = get_example_filename()
|
||||
data = get_example_data()
|
||||
chunks = 5
|
||||
chunk_size = 4096
|
||||
data = get_example_data(length=chunk_size * chunks)
|
||||
|
||||
self.storage.save(filename, data)
|
||||
generator = self.storage.load_stream(filename)
|
||||
assert isinstance(generator, Generator)
|
||||
assert next(generator) == data
|
||||
for i in range(chunks):
|
||||
fetched = next(generator)
|
||||
assert len(fetched) == chunk_size
|
||||
assert fetched == data[i * chunk_size : (i + 1) * chunk_size]
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
def test_download(self):
|
||||
"""Test downloading data to a file."""
|
||||
|
||||
@ -3,6 +3,7 @@ Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
@ -13,12 +14,14 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
from core.workflow.entities import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
@ -85,26 +88,41 @@ def test_save(repository, session):
|
||||
"""Test save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.id = "test-id"
|
||||
execution.node_execution_id = "test-node-execution-id"
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
execution.inputs = None
|
||||
execution.process_data = None
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
execution.workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Mock the to_db_model method to return the execution itself
|
||||
# This simulates the behavior of setting tenant_id and app_id
|
||||
repository.to_db_model = MagicMock(return_value=execution)
|
||||
db_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
repository._to_db_model = MagicMock(return_value=db_model)
|
||||
|
||||
# Mock session.get to return None (no existing record)
|
||||
session_obj.get.return_value = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called (now using merge for both save and update)
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
# Assert session.get was called to check for existing record
|
||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)
|
||||
|
||||
# Assert session.add was called for new record
|
||||
session_obj.add.assert_called_once_with(db_model)
|
||||
|
||||
# Assert session.commit was called
|
||||
session_obj.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_with_existing_tenant_id(repository, session):
|
||||
@ -112,6 +130,8 @@ def test_save_with_existing_tenant_id(repository, session):
|
||||
session_obj, _ = session
|
||||
# Create a mock execution with existing tenant_id
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = "existing-id"
|
||||
execution.node_execution_id = "existing-node-execution-id"
|
||||
execution.tenant_id = "existing-tenant"
|
||||
execution.app_id = None
|
||||
execution.inputs = None
|
||||
@ -121,20 +141,39 @@ def test_save_with_existing_tenant_id(repository, session):
|
||||
|
||||
# Create a modified execution that will be returned by _to_db_model
|
||||
modified_execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
modified_execution.id = "existing-id"
|
||||
modified_execution.node_execution_id = "existing-node-execution-id"
|
||||
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
|
||||
modified_execution.app_id = repository._app_id # App ID should be set
|
||||
# Create a dictionary to simulate __dict__ for updating attributes
|
||||
modified_execution.__dict__ = {
|
||||
"id": "existing-id",
|
||||
"node_execution_id": "existing-node-execution-id",
|
||||
"tenant_id": "existing-tenant",
|
||||
"app_id": repository._app_id,
|
||||
}
|
||||
|
||||
# Mock the to_db_model method to return the modified execution
|
||||
repository.to_db_model = MagicMock(return_value=modified_execution)
|
||||
repository._to_db_model = MagicMock(return_value=modified_execution)
|
||||
|
||||
# Mock session.get to return an existing record
|
||||
existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
session_obj.get.return_value = existing_model
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called with the modified execution (now using merge for both save and update)
|
||||
session_obj.merge.assert_called_once_with(modified_execution)
|
||||
# Assert session.get was called to check for existing record
|
||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)
|
||||
|
||||
# Assert session.add was NOT called since we're updating existing
|
||||
session_obj.add.assert_not_called()
|
||||
|
||||
# Assert session.commit was called
|
||||
session_obj.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
@ -142,10 +181,19 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||
mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
|
||||
mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")
|
||||
|
||||
mock_WorkflowNodeExecutionModel = mocker.patch(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
|
||||
)
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
mock_asc.return_value = mock_stmt
|
||||
mock_desc.return_value = mock_stmt
|
||||
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt
|
||||
|
||||
# Create a properly configured mock execution
|
||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
@ -164,6 +212,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
|
||||
# Assert _to_domain_model was called with the mock execution
|
||||
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||
# Assert the result contains our mock domain model
|
||||
@ -199,7 +248,7 @@ def test_to_db_model(repository):
|
||||
)
|
||||
|
||||
# Convert to DB model
|
||||
db_model = repository.to_db_model(domain_model)
|
||||
db_model = repository._to_db_model(domain_model)
|
||||
|
||||
# Assert DB model has correct values
|
||||
assert isinstance(db_model, WorkflowNodeExecutionModel)
|
||||
|
||||
@ -0,0 +1,106 @@
|
||||
"""
|
||||
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType
|
||||
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
||||
"""Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""
|
||||
|
||||
def create_mock_account(self) -> Account:
|
||||
"""Create a mock Account for testing."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = "test-user-id"
|
||||
account.tenant_id = "test-tenant-id"
|
||||
return account
|
||||
|
||||
def create_mock_session_factory(self) -> sessionmaker:
|
||||
"""Create a mock session factory for testing."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_session_factory.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.return_value.__exit__.return_value = None
|
||||
return mock_session_factory
|
||||
|
||||
def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""Create a repository instance for testing."""
|
||||
mock_account = self.create_mock_account()
|
||||
mock_session_factory = self.create_mock_session_factory()
|
||||
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if mock_file_service:
|
||||
repository._file_service = mock_file_service
|
||||
|
||||
return repository
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
execution_id: str = "test-execution-id",
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution instance for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
workflow_id="test-workflow-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def test_to_domain_model_without_offload_data(self):
|
||||
"""Test _to_domain_model without offload data."""
|
||||
repository = self.create_repository()
|
||||
|
||||
# Create mock database model without offload data
|
||||
db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-execution-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.workflow_run_id = None
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = "llm"
|
||||
db_model.title = "Test Node"
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.finished_at = None
|
||||
|
||||
process_data = {"normal": "data"}
|
||||
db_model.process_data_dict = process_data
|
||||
db_model.inputs_dict = None
|
||||
db_model.outputs_dict = None
|
||||
db_model.execution_metadata_dict = {}
|
||||
db_model.offload_data = None
|
||||
|
||||
domain_model = repository._to_domain_model(db_model)
|
||||
|
||||
# Domain model should have the data from database
|
||||
assert domain_model.process_data == process_data
|
||||
|
||||
# Should not be truncated
|
||||
assert domain_model.process_data_truncated is False
|
||||
assert domain_model.get_truncated_process_data() is None
|
||||
@ -28,18 +28,20 @@ class TestApiKeyAuthService:
|
||||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tenant_id == self.tenant_id
|
||||
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
|
||||
assert mock_session.scalars.call_count == 1
|
||||
select_arg = mock_session.scalars.call_args[0][0]
|
||||
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
@ -48,13 +50,15 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
# Verify where conditions include disabled.is_(False)
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2 # tenant_id and disabled filter conditions
|
||||
select_stmt = mock_session.scalars.call_args[0][0]
|
||||
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
|
||||
# Ensure both tenant filter and disabled filter exist
|
||||
where_strs = [str(c).lower() for c in where_clauses]
|
||||
assert any("tenant_id" in s for s in where_strs)
|
||||
assert any("disabled" in s for s in where_strs)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
|
||||
@ -6,8 +6,8 @@ import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -26,7 +26,7 @@ class TestAuthIntegration:
|
||||
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test complete authentication flow: request → validation → encryption → storage"""
|
||||
@ -47,7 +47,7 @@ class TestAuthIntegration:
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_cross_component_integration(self, mock_http):
|
||||
"""Test factory → provider → HTTP call integration"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
@ -63,10 +63,10 @@ class TestAuthIntegration:
|
||||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
@ -97,7 +97,7 @@ class TestAuthIntegration:
|
||||
assert "another_secret" not in factory_str
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test concurrent authentication creation safety"""
|
||||
@ -142,31 +142,31 @@ class TestAuthIntegration:
|
||||
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
|
||||
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_http_error_handling(self, mock_http):
|
||||
"""Test proper HTTP error handling"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = '{"error": "Unauthorized"}'
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
|
||||
mock_http.return_value = mock_response
|
||||
|
||||
# PT012: Split into single statement for pytest.raises
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
with pytest.raises((requests.exceptions.HTTPError, Exception)):
|
||||
with pytest.raises((httpx.HTTPError, Exception)):
|
||||
factory.validate_credentials()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_network_failure_recovery(self, mock_http, mock_session):
|
||||
"""Test system recovery from network failures"""
|
||||
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
|
||||
mock_http.side_effect = httpx.RequestError("Network timeout")
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user