mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -90,6 +90,7 @@ def test_flask_configs(monkeypatch):
|
||||
"pool_recycle": 3600,
|
||||
"pool_size": 30,
|
||||
"pool_use_lifo": False,
|
||||
"pool_reset_on_return": None,
|
||||
}
|
||||
|
||||
assert config["CONSOLE_WEB_URL"] == "https://example.com"
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
"""Test authentication security to prevent user enumeration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
"""Test authentication endpoints for security against user enumeration."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.app = Flask(__name__)
|
||||
self.api = Api(self.app)
|
||||
self.api.add_resource(LoginApi, "/login")
|
||||
self.client = self.app.test_client()
|
||||
self.app.config["TESTING"] = True
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email sends reset password email when registration is allowed."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
result = login_api.post()
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
|
||||
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
"""Test that wrong password returns AuthenticationFailedError."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("existing@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AccountNotFound when registration is disabled."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AccountNotFound) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "account_not_found"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
|
||||
from controllers.console.auth.login import ResetPasswordSendEmailApi
|
||||
|
||||
api = ResetPasswordSendEmailApi()
|
||||
result = api.post()
|
||||
|
||||
assert result == {"result": "success", "data": "token123"}
|
||||
0
api/tests/unit_tests/core/plugin/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/utils/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/utils/__init__.py
Normal file
460
api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
460
api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
@ -0,0 +1,460 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class TestChunkMerger:
|
||||
def test_file_chunk_initialization(self):
|
||||
"""Test FileChunk initialization."""
|
||||
chunk = FileChunk(1024)
|
||||
assert chunk.bytes_written == 0
|
||||
assert chunk.total_length == 1024
|
||||
assert len(chunk.data) == 1024
|
||||
|
||||
def test_merge_blob_chunks_with_single_complete_chunk(self):
|
||||
"""Test merging a single complete blob chunk."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# First chunk (partial)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
|
||||
),
|
||||
)
|
||||
# Second chunk (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=10, blob=b"World", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# The buffer should contain the complete data
|
||||
assert result[0].message.blob[:10] == b"HelloWorld"
|
||||
|
||||
def test_merge_blob_chunks_with_multiple_files(self):
|
||||
"""Test merging chunks from multiple files."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# File 1, chunk 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=4, blob=b"AB", end=False
|
||||
),
|
||||
)
|
||||
# File 2, chunk 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file2", sequence=0, total_length=4, blob=b"12", end=False
|
||||
),
|
||||
)
|
||||
# File 1, chunk 2 (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=4, blob=b"CD", end=True
|
||||
),
|
||||
)
|
||||
# File 2, chunk 2 (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file2", sequence=1, total_length=4, blob=b"34", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 2
|
||||
# Check that both files are properly merged
|
||||
assert all(r.type == ToolInvokeMessage.MessageType.BLOB for r in result)
|
||||
|
||||
def test_merge_blob_chunks_passes_through_non_blob_messages(self):
|
||||
"""Test that non-blob messages pass through unchanged."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Text message
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="Hello"),
|
||||
)
|
||||
# Blob chunk
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=5, blob=b"Test", end=True
|
||||
),
|
||||
)
|
||||
# Another text message
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="World"),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 3
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.TextMessage)
|
||||
assert result[0].message.text == "Hello"
|
||||
assert result[1].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert result[2].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(result[2].message, ToolInvokeMessage.TextMessage)
|
||||
assert result[2].message.text == "World"
|
||||
|
||||
def test_merge_blob_chunks_file_too_large(self):
|
||||
"""Test that error is raised when file exceeds max size."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Send a chunk that would exceed the limit
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=100, blob=b"x" * 1024, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
list(merge_blob_chunks(mock_generator(), max_file_size=1000))
|
||||
assert "File is too large" in str(exc_info.value)
|
||||
|
||||
def test_merge_blob_chunks_chunk_too_large(self):
|
||||
"""Test that error is raised when chunk exceeds max chunk size."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Send a chunk that exceeds the max chunk size
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10000, blob=b"x" * 9000, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
list(merge_blob_chunks(mock_generator(), max_chunk_size=8192))
|
||||
assert "File chunk is too large" in str(exc_info.value)
|
||||
|
||||
def test_merge_blob_chunks_with_agent_invoke_message(self):
|
||||
"""Test that merge_blob_chunks works with AgentInvokeMessage."""
|
||||
|
||||
def mock_generator() -> Generator[AgentInvokeMessage, None, None]:
|
||||
# First chunk
|
||||
yield AgentInvokeMessage(
|
||||
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=AgentInvokeMessage.BlobChunkMessage(
|
||||
id="agent_file", sequence=0, total_length=8, blob=b"Agent", end=False
|
||||
),
|
||||
)
|
||||
# Final chunk
|
||||
yield AgentInvokeMessage(
|
||||
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=AgentInvokeMessage.BlobChunkMessage(
|
||||
id="agent_file", sequence=1, total_length=8, blob=b"Data", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], AgentInvokeMessage)
|
||||
assert result[0].type == AgentInvokeMessage.MessageType.BLOB
|
||||
|
||||
def test_merge_blob_chunks_preserves_meta(self):
|
||||
"""Test that meta information is preserved in merged messages."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=4, blob=b"Test", end=True
|
||||
),
|
||||
meta={"key": "value"},
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].meta == {"key": "value"}
|
||||
|
||||
def test_merge_blob_chunks_custom_limits(self):
|
||||
"""Test merge_blob_chunks with custom size limits."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# This should work with custom limits
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
|
||||
),
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=500, blob=b"y" * 100, end=True
|
||||
),
|
||||
)
|
||||
|
||||
# Should work with custom limits
|
||||
result = list(merge_blob_chunks(mock_generator(), max_file_size=1000, max_chunk_size=500))
|
||||
assert len(result) == 1
|
||||
|
||||
# Should fail with smaller file size limit
|
||||
def mock_generator2() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(merge_blob_chunks(mock_generator2(), max_file_size=300))
|
||||
|
||||
def test_merge_blob_chunks_data_integrity(self):
|
||||
"""Test that merged chunks exactly match the original data."""
|
||||
# Create original data
|
||||
original_data = b"This is a test message that will be split into chunks for testing purposes."
|
||||
chunk_size = 20
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Split original data into chunks
|
||||
chunks = []
|
||||
for i in range(0, len(original_data), chunk_size):
|
||||
chunk_data = original_data[i : i + chunk_size]
|
||||
is_last = (i + chunk_size) >= len(original_data)
|
||||
chunks.append((i // chunk_size, chunk_data, is_last))
|
||||
|
||||
# Yield chunks
|
||||
for sequence, data, is_end in chunks:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="test_file",
|
||||
sequence=sequence,
|
||||
total_length=len(original_data),
|
||||
blob=data,
|
||||
end=is_end,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# Verify the merged data exactly matches the original
|
||||
assert result[0].message.blob == original_data
|
||||
|
||||
def test_merge_blob_chunks_empty_chunk(self):
|
||||
"""Test handling of empty chunks."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# First chunk with data
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
|
||||
),
|
||||
)
|
||||
# Empty chunk in the middle
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=10, blob=b"", end=False
|
||||
),
|
||||
)
|
||||
# Final chunk with data
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=2, total_length=10, blob=b"World", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# The final blob should contain "Hello" followed by "World"
|
||||
assert result[0].message.blob[:10] == b"HelloWorld"
|
||||
|
||||
def test_merge_blob_chunks_single_chunk_file(self):
|
||||
"""Test file that arrives as a single complete chunk."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Single chunk that is both first and last
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="single_chunk_file",
|
||||
sequence=0,
|
||||
total_length=11,
|
||||
blob=b"Single Data",
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert result[0].message.blob == b"Single Data"
|
||||
|
||||
def test_merge_blob_chunks_concurrent_files(self):
|
||||
"""Test that chunks from different files are properly separated."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Interleave chunks from three different files
|
||||
files_data = {
|
||||
"file1": b"First file content",
|
||||
"file2": b"Second file data",
|
||||
"file3": b"Third file",
|
||||
}
|
||||
|
||||
# First chunk from each file
|
||||
for file_id, data in files_data.items():
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id=file_id,
|
||||
sequence=0,
|
||||
total_length=len(data),
|
||||
blob=data[:6],
|
||||
end=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Second chunk from each file (final)
|
||||
for file_id, data in files_data.items():
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id=file_id,
|
||||
sequence=1,
|
||||
total_length=len(data),
|
||||
blob=data[6:],
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 3
|
||||
|
||||
# Extract the blob data from results
|
||||
blobs = set()
|
||||
for r in result:
|
||||
assert isinstance(r.message, ToolInvokeMessage.BlobMessage)
|
||||
blobs.add(r.message.blob)
|
||||
expected = {b"First file content", b"Second file data", b"Third file"}
|
||||
assert blobs == expected
|
||||
|
||||
def test_merge_blob_chunks_exact_buffer_size(self):
|
||||
"""Test that data fitting exactly in buffer works correctly."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Create data that exactly fills the declared buffer
|
||||
exact_data = b"X" * 100
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="exact_file",
|
||||
sequence=0,
|
||||
total_length=100,
|
||||
blob=exact_data[:50],
|
||||
end=False,
|
||||
),
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="exact_file",
|
||||
sequence=1,
|
||||
total_length=100,
|
||||
blob=exact_data[50:],
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert len(result[0].message.blob) == 100
|
||||
assert result[0].message.blob == b"X" * 100
|
||||
|
||||
def test_merge_blob_chunks_large_file_simulation(self):
|
||||
"""Test handling of a large file split into many chunks."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Simulate a 1MB file split into 128 chunks of 8KB each
|
||||
chunk_size = 8192
|
||||
num_chunks = 128
|
||||
total_size = chunk_size * num_chunks
|
||||
|
||||
for i in range(num_chunks):
|
||||
# Create unique data for each chunk to verify ordering
|
||||
chunk_data = bytes([i % 256]) * chunk_size
|
||||
is_last = i == num_chunks - 1
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="large_file",
|
||||
sequence=i,
|
||||
total_length=total_size,
|
||||
blob=chunk_data,
|
||||
end=is_last,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert len(result[0].message.blob) == 1024 * 1024
|
||||
|
||||
# Verify the data pattern is correct
|
||||
merged_data = result[0].message.blob
|
||||
chunk_size = 8192
|
||||
num_chunks = 128
|
||||
for i in range(num_chunks):
|
||||
chunk_start = i * chunk_size
|
||||
chunk_end = chunk_start + chunk_size
|
||||
expected_byte = i % 256
|
||||
chunk = merged_data[chunk_start:chunk_end]
|
||||
assert all(b == expected_byte for b in chunk), f"Chunk {i} has incorrect data"
|
||||
|
||||
def test_merge_blob_chunks_sequential_order_required(self):
|
||||
"""
|
||||
Test note: The current implementation assumes chunks arrive in sequential order.
|
||||
Out-of-order chunks would need additional logic to handle properly.
|
||||
This test documents the expected behavior with sequential chunks.
|
||||
"""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Chunks arriving in correct sequential order
|
||||
data_parts = [b"First", b"Second", b"Third"]
|
||||
total_length = sum(len(part) for part in data_parts)
|
||||
|
||||
for i, part in enumerate(data_parts):
|
||||
is_last = i == len(data_parts) - 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="ordered_file",
|
||||
sequence=i,
|
||||
total_length=total_length,
|
||||
blob=part,
|
||||
end=is_last,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert result[0].message.blob == b"FirstSecondThird"
|
||||
308
api/tests/unit_tests/core/test_provider_configuration.py
Normal file
308
api/tests/unit_tests/core/test_provider_configuration.py
Normal file
@ -0,0 +1,308 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
ModelSettings,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
RestrictModel,
|
||||
SystemConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity():
|
||||
"""Mock provider entity with basic configuration"""
|
||||
provider_entity = ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
|
||||
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
background="background.png",
|
||||
help=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
)
|
||||
|
||||
return provider_entity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_system_configuration():
|
||||
"""Mock system configuration"""
|
||||
quota_config = QuotaConfiguration(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=1000,
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
|
||||
)
|
||||
|
||||
system_config = SystemConfiguration(
|
||||
enabled=True,
|
||||
credentials={"openai_api_key": "test_key"},
|
||||
quota_configurations=[quota_config],
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
|
||||
return system_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_configuration():
|
||||
"""Mock custom configuration"""
|
||||
custom_config = CustomConfiguration(provider=None, models=[])
|
||||
return custom_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
|
||||
"""Create a test provider configuration instance"""
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
return ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfiguration:
|
||||
"""Test cases for ProviderConfiguration class"""
|
||||
|
||||
def test_get_current_credentials_system_provider_success(self, provider_configuration):
|
||||
"""Test successfully getting credentials from system provider"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_get_current_credentials_model_disabled(self, provider_configuration):
|
||||
"""Test getting credentials when model is disabled"""
|
||||
# Arrange
|
||||
model_setting = ModelSettings(
|
||||
model="gpt-4",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=False,
|
||||
load_balancing_configs=[],
|
||||
has_invalid_load_balancing_configs=False,
|
||||
)
|
||||
provider_configuration.model_settings = [model_setting]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
|
||||
provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
|
||||
"""Test getting credentials from custom provider with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.CUSTOM
|
||||
|
||||
mock_model_config = Mock()
|
||||
mock_model_config.model_type = ModelType.LLM
|
||||
mock_model_config.model = "gpt-4"
|
||||
mock_model_config.credentials = {"openai_api_key": "custom_key"}
|
||||
provider_configuration.custom_configuration.models = [mock_model_config]
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "custom_key"}
|
||||
|
||||
def test_get_system_configuration_status_active(self, provider_configuration):
|
||||
"""Test getting active system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.ACTIVE
|
||||
|
||||
def test_get_system_configuration_status_unsupported(self, provider_configuration):
|
||||
"""Test getting unsupported system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
|
||||
"""Test getting quota exceeded system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
quota_config = provider_configuration.system_configuration.quota_configurations[0]
|
||||
quota_config.is_valid = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def test_is_custom_configuration_available_with_provider(self, provider_configuration):
|
||||
"""Test custom configuration availability with provider credentials"""
|
||||
# Arrange
|
||||
mock_provider = Mock()
|
||||
mock_provider.available_credentials = ["openai_api_key"]
|
||||
provider_configuration.custom_configuration.provider = mock_provider
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_with_models(self, provider_configuration):
|
||||
"""Test custom configuration availability with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = [Mock()]
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_false(self, provider_configuration):
|
||||
"""Test custom configuration not available"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record successfully"""
|
||||
# Arrange
|
||||
mock_provider = Mock(spec=Provider)
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result == mock_provider
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record when not found"""
|
||||
# Arrange
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_init_with_customizable_model_only(
|
||||
self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
|
||||
):
|
||||
"""Test initialization with customizable model only configuration"""
|
||||
# Arrange
|
||||
mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]
|
||||
|
||||
# Act
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
config = ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods
|
||||
|
||||
def test_get_current_credentials_with_restricted_models(self, provider_configuration):
|
||||
"""Test getting credentials with model restrictions"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")
|
||||
|
||||
# Assert
|
||||
assert credentials is not None
|
||||
assert "openai_api_key" in credentials
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential successfully"""
|
||||
# Arrange
|
||||
credential_id = "test_credential_id"
|
||||
mock_credential = Mock()
|
||||
mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}'
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential
|
||||
|
||||
# Act
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = {"openai_api_key": "test_key"}
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
|
||||
# Assert
|
||||
assert result == {"openai_api_key": "test_key"}
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential when not found"""
|
||||
# Arrange
|
||||
credential_id = "nonexistent_credential_id"
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = None
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
assert result is None
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
@ -1,190 +1,185 @@
|
||||
# from core.entities.provider_entities import ModelSettings
|
||||
# from core.model_runtime.entities.model_entities import ModelType
|
||||
# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
# from core.provider_manager import ProviderManager
|
||||
# from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
# def test__to_model_settings(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 2
|
||||
# assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
# assert result[0].load_balancing_configs[1].name == "first"
|
||||
return mock_entity
|
||||
|
||||
|
||||
# def test__to_model_settings_only_one_lb(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# )
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(
|
||||
# provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 2
|
||||
assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
# def test__to_model_settings_lb_disabled(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
)
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=False,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
# return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
expected_non_array_types = [
|
||||
SegmentType.INTEGER,
|
||||
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
SegmentType.BOOLEAN,
|
||||
]
|
||||
|
||||
for seg_type in expected_array_types:
|
||||
|
||||
@ -0,0 +1,729 @@
|
||||
"""
|
||||
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
|
||||
|
||||
This module provides thorough testing of the validation logic for all SegmentType values,
|
||||
including edge cases, error conditions, and different ArrayValidation strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
def create_test_file(
|
||||
file_type: FileType = FileType.DOCUMENT,
|
||||
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
|
||||
filename: str = "test.txt",
|
||||
extension: str = ".txt",
|
||||
mime_type: str = "text/plain",
|
||||
size: int = 1024,
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing."""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
|
||||
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
storage_key="test-storage-key",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationTestCase:
|
||||
"""Test case data structure for validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArrayValidationTestCase:
|
||||
"""Test case data structure for array validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
array_validation: ArrayValidation
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
# Test data construction functions
|
||||
def get_boolean_cases() -> list[ValidationTestCase]:
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
|
||||
# Invalid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_number_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid number values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
|
||||
# invalid number values
|
||||
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
|
||||
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
|
||||
]
|
||||
|
||||
|
||||
def get_string_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid string values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
|
||||
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
|
||||
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
|
||||
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
|
||||
# invalid values
|
||||
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_object_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid object values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
|
||||
]
|
||||
|
||||
|
||||
def get_secret_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid secret values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
|
||||
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_file_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid file values."""
|
||||
test_file = create_test_file()
|
||||
image_file = create_test_file(
|
||||
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
|
||||
)
|
||||
remote_file = create_test_file(
|
||||
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
|
||||
)
|
||||
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
|
||||
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
|
||||
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
|
||||
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
|
||||
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
|
||||
]
|
||||
|
||||
|
||||
def get_none_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid none values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
|
||||
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
|
||||
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_ANY validation."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"Mixed types with FIRST validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Mixed types with ALL validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", "world"],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Valid strings with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, 456],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Invalid elements with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["valid", 123, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", 123, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid, others invalid",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, "hello", "world"],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid, others valid",
|
||||
),
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
[float("inf"), float("-inf"), float("nan")],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Special float values",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "not an object"],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid object",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
["not an object", {"valid": "object"}],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid",
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{}, {"a": 1}, {"nested": {"key": "value"}}],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"All valid objects",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "invalid", {"another": "object"}],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_FILE validation with different strategies."""
|
||||
file1 = create_test_file(filename="file1.txt")
|
||||
file2 = create_test_file(filename="file2.txt")
|
||||
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
[True, "false", False],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element (string)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestSegmentTypeIsValid:
|
||||
"""Test suite for SegmentType.is_valid method covering all non-array types."""
|
||||
|
||||
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
|
||||
def test_boolean_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
|
||||
def test_number_validation(self, case: ValidationTestCase):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
|
||||
def test_string_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
|
||||
def test_object_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
|
||||
def test_secret_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
|
||||
def test_file_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
|
||||
def test_none_validation_valid_cases(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
def test_unsupported_segment_type_raises_assertion_error(self):
|
||||
"""Test that unsupported SegmentType values raise AssertionError."""
|
||||
# GROUP is not handled in is_valid method
|
||||
with pytest.raises(AssertionError, match="this statement should be unreachable"):
|
||||
SegmentType.GROUP.is_valid("any value")
|
||||
|
||||
|
||||
class TestSegmentTypeArrayValidation:
|
||||
"""Test suite for SegmentType._validate_array method and array type validation."""
|
||||
|
||||
def test_array_validation_non_list_values(self):
|
||||
"""Test that non-list values return False for all array types."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
non_list_values = [
|
||||
"not a list",
|
||||
123,
|
||||
3.14,
|
||||
True,
|
||||
None,
|
||||
{"key": "value"},
|
||||
create_test_file(),
|
||||
]
|
||||
|
||||
for array_type in array_types:
|
||||
for value in non_list_values:
|
||||
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
|
||||
|
||||
def test_empty_array_validation(self):
|
||||
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
|
||||
|
||||
for array_type in array_types:
|
||||
for strategy in validation_strategies:
|
||||
assert array_type.is_valid([], strategy) is True, (
|
||||
f"{array_type} should accept empty array with {strategy}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_any_validation(self, case):
|
||||
"""Test ARRAY_ANY validation accepts any list regardless of content."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_none_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_first_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_all_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_number_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_NUMBER validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_object_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_OBJECT validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_file_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_FILE validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_boolean_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
def test_default_array_validation_strategy(self):
|
||||
"""Test that default array validation strategy is FIRST."""
|
||||
# When no array_validation parameter is provided, it should default to FIRST
|
||||
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
|
||||
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
|
||||
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
|
||||
|
||||
def test_array_validation_edge_cases(self):
|
||||
"""Test edge cases for array validation."""
|
||||
# Test with nested arrays (should be invalid for specific array types)
|
||||
nested_array = [["nested", "array"], ["another", "nested"]]
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
|
||||
|
||||
# Test with very large arrays (performance consideration)
|
||||
large_valid_array = ["string"] * 1000
|
||||
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
|
||||
|
||||
|
||||
class TestSegmentTypeValidationIntegration:
|
||||
"""Integration tests for SegmentType validation covering interactions between methods."""
|
||||
|
||||
def test_non_array_types_ignore_array_validation_parameter(self):
|
||||
"""Test that non-array types ignore the array_validation parameter."""
|
||||
non_array_types = [
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
]
|
||||
|
||||
for segment_type in non_array_types:
|
||||
# Create appropriate valid value for each type
|
||||
valid_value: Any
|
||||
if segment_type == SegmentType.STRING:
|
||||
valid_value = "test"
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
valid_value = 42
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
valid_value = True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
valid_value = {"key": "value"}
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
valid_value = "secret"
|
||||
elif segment_type == SegmentType.FILE:
|
||||
valid_value = create_test_file()
|
||||
elif segment_type == SegmentType.NONE:
|
||||
valid_value = None
|
||||
else:
|
||||
continue # Skip unsupported types
|
||||
|
||||
# All array validation strategies should give the same result
|
||||
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
|
||||
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
|
||||
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
|
||||
|
||||
assert result_none == result_first == result_all == True, (
|
||||
f"{segment_type} should ignore array_validation parameter"
|
||||
)
|
||||
|
||||
def test_comprehensive_type_coverage(self):
|
||||
"""Test that all SegmentType enum values are covered in validation tests."""
|
||||
all_segment_types = set(SegmentType)
|
||||
|
||||
# Types that should be handled by is_valid method
|
||||
handled_types = {
|
||||
# Non-array types
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
}
|
||||
|
||||
# Types that are not handled by is_valid (should raise AssertionError)
|
||||
unhandled_types = {
|
||||
SegmentType.GROUP,
|
||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||
}
|
||||
|
||||
# Verify all types are accounted for
|
||||
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
|
||||
|
||||
# Test that handled types work correctly
|
||||
for segment_type in handled_types:
|
||||
if segment_type.is_array_type():
|
||||
# Test with empty array (should always be valid)
|
||||
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
|
||||
else:
|
||||
# Test with appropriate valid value
|
||||
if segment_type == SegmentType.STRING:
|
||||
assert segment_type.is_valid("test") is True
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
assert segment_type.is_valid(42) is True
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
assert segment_type.is_valid(True) is True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
assert segment_type.is_valid({}) is True
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
assert segment_type.is_valid("secret") is True
|
||||
elif segment_type == SegmentType.FILE:
|
||||
assert segment_type.is_valid(create_test_file()) is True
|
||||
elif segment_type == SegmentType.NONE:
|
||||
assert segment_type.is_valid(None) is True
|
||||
|
||||
def test_boolean_vs_integer_type_distinction(self):
|
||||
"""Test the important distinction between boolean and integer types in validation."""
|
||||
# This tests the comment in the code about bool being a subclass of int
|
||||
|
||||
# Boolean type should only accept actual booleans, not integers
|
||||
assert SegmentType.BOOLEAN.is_valid(True) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(False) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
|
||||
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
|
||||
|
||||
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
|
||||
assert SegmentType.NUMBER.is_valid(42) is True
|
||||
assert SegmentType.NUMBER.is_valid(3.14) is True
|
||||
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
|
||||
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
|
||||
|
||||
def test_array_validation_recursive_behavior(self):
|
||||
"""Test that array validation correctly handles recursive validation calls."""
|
||||
# When validating array elements, _validate_array calls is_valid recursively
|
||||
# with ArrayValidation.NONE to avoid infinite recursion
|
||||
|
||||
# Test nested validation doesn't cause issues
|
||||
nested_arrays = [["inner", "array"], ["another", "inner"]]
|
||||
|
||||
# ARRAY_ANY should accept nested arrays
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
|
||||
|
||||
# ARRAY_STRING should reject nested arrays (first element is not a string)
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False
|
||||
@ -0,0 +1,27 @@
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
|
||||
|
||||
|
||||
class TestParameterConfig:
|
||||
def test_select_type(self):
|
||||
data = {
|
||||
"name": "yes_or_no",
|
||||
"type": "select",
|
||||
"options": ["yes", "no"],
|
||||
"description": "a simple select made of `yes` and `no`",
|
||||
"required": True,
|
||||
}
|
||||
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.STRING
|
||||
assert pc.options == data["options"]
|
||||
|
||||
def test_validate_bool_type(self):
|
||||
data = {
|
||||
"name": "boolean",
|
||||
"type": "bool",
|
||||
"description": "a simple boolean parameter",
|
||||
"required": True,
|
||||
}
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.BOOLEAN
|
||||
@ -0,0 +1,567 @@
|
||||
"""
|
||||
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
|
||||
from core.workflow.nodes.parameter_extractor.exc import (
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidSelectValueError,
|
||||
InvalidValueTypeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidTestCase:
|
||||
"""Test case data for valid scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorTestCase:
|
||||
"""Test case data for error scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
expected_exception: type[Exception]
|
||||
expected_message: str
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformTestCase:
|
||||
"""Test case data for transformation scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
input_result: dict[str, Any]
|
||||
expected_result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class TestParameterExtractorNodeMethods:
|
||||
"""Test helper class that provides access to the methods under test."""
|
||||
|
||||
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _validate_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._validate_result(data=data, result=result)
|
||||
|
||||
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _transform_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._transform_result(data=data, result=result)
|
||||
|
||||
|
||||
class TestValidateResult:
|
||||
"""Test cases for _validate_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_valid_test_cases() -> list[ValidTestCase]:
|
||||
"""Get test cases that should pass validation."""
|
||||
return [
|
||||
ValidTestCase(
|
||||
name="single_string_parameter",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": 25},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
result={"price": 19.99},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_false",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": False},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="select_parameter_valid_option",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # pyright: ignore[reportArgumentType]
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "active"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_string_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", "tag2", "tag3"]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_number_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, 92.5, 78]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_object_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="multiple_parameters",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
],
|
||||
result={"name": "John", "age": 25, "active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="optional_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
|
||||
],
|
||||
result={"name": "John", "nickname": "Johnny"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="empty_array_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": []},
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_error_test_cases() -> list[ErrorTestCase]:
|
||||
"""Get test cases that should raise exceptions."""
|
||||
return [
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_few",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
],
|
||||
result={"name": "John"},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_many",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John", "age": 25},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_none",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
],
|
||||
result={"name": None}, # Parameter present but None value, will trigger type check first
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_select_value",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "pending"},
|
||||
expected_exception=InvalidSelectValueError,
|
||||
expected_message="Invalid `select` value for parameter status",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_value_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": "twenty-five"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_bool_value_string",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": "yes"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_number",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="description", type=SegmentType.STRING, description="Description", required=True
|
||||
)
|
||||
],
|
||||
result={"description": 123},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_value_not_list",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": "tag1,tag2,tag3"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_number_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, "ninety-two", 78]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_string_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", 123, "tag3"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_object_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, "item2"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="required_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
|
||||
],
|
||||
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
|
||||
expected_exception=RequiredParameterMissingError,
|
||||
expected_message="Parameter name is required",
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
|
||||
def test_validate_result_valid_cases(self, test_case):
|
||||
"""Test _validate_result with valid inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.validate_result(data=node_data, result=test_case.result)
|
||||
assert result == test_case.result, f"Failed for case: {test_case.name}"
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
|
||||
def test_validate_result_error_cases(self, test_case):
|
||||
"""Test _validate_result with invalid inputs that should raise exceptions."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
with pytest.raises(test_case.expected_exception) as exc_info:
|
||||
helper.validate_result(data=node_data, result=test_case.result)
|
||||
|
||||
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
|
||||
|
||||
|
||||
class TestTransformResult:
|
||||
"""Test cases for _transform_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_transform_test_cases() -> list[TransformTestCase]:
|
||||
"""Get test cases for result transformation."""
|
||||
return [
|
||||
# String parameter transformation
|
||||
TransformTestCase(
|
||||
name="string_parameter_present",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={"name": "John"},
|
||||
expected_result={"name": "John"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="string_parameter_missing",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={},
|
||||
expected_result={"name": ""},
|
||||
),
|
||||
# Number parameter transformation
|
||||
TransformTestCase(
|
||||
name="number_parameter_int_present",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": 25},
|
||||
expected_result={"age": 25},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_float_present",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": 19.99},
|
||||
expected_result={"price": 19.99},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_missing",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={},
|
||||
expected_result={"age": 0},
|
||||
),
|
||||
# Bool parameter transformation
|
||||
TransformTestCase(
|
||||
name="bool_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"active": False},
|
||||
),
|
||||
# Select parameter transformation
|
||||
TransformTestCase(
|
||||
name="select_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={"status": "active"},
|
||||
expected_result={"status": "active"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="select_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"status": ""},
|
||||
),
|
||||
# Array parameter transformation - present cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={"tags": ["tag1", "tag2"]},
|
||||
expected_result={
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, 92.5]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "92.5", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
expected_result={
|
||||
"items": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
|
||||
)
|
||||
},
|
||||
),
|
||||
# Array parameter transformation - missing cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
|
||||
),
|
||||
# Multiple parameters transformation
|
||||
TransformTestCase(
|
||||
name="multiple_parameters_mixed",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
|
||||
],
|
||||
input_result={"name": "John", "age": 25},
|
||||
expected_result={
|
||||
"name": "John",
|
||||
"age": 25,
|
||||
"active": False,
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
|
||||
},
|
||||
),
|
||||
# Number parameter transformation with string conversion
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": "19.99"},
|
||||
expected_result={"price": 19.99}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "25"},
|
||||
expected_result={"age": 25}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_invalid_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "invalid_number"},
|
||||
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_non_string_non_number",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
|
||||
expected_result={"age": 0}, # Falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_invalid_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "invalid", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
|
||||
) # Invalid string skipped
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
|
||||
def test_transform_result_cases(self, test_case):
|
||||
"""Test _transform_result with various inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.transform_result(data=node_data, result=test_case.input_result)
|
||||
assert result == test_case.expected_result, (
|
||||
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
|
||||
)
|
||||
@ -2,6 +2,8 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def _get_test_conditions() -> list:
|
||||
conditions = [
|
||||
# Test boolean "is" operator
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
|
||||
# Test boolean "is not" operator
|
||||
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
|
||||
# Test boolean "=" operator
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
|
||||
# Test boolean "≠" operator
|
||||
{"comparison_operator": "≠", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not null" operator
|
||||
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
|
||||
# Test boolean array "contains" operator
|
||||
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
|
||||
# Test boolean "in" operator
|
||||
{
|
||||
"comparison_operator": "in",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": ["true", "false"],
|
||||
},
|
||||
]
|
||||
return [Condition.model_validate(i) for i in conditions]
|
||||
|
||||
|
||||
def _get_condition_test_id(c: Condition):
|
||||
return c.comparison_operator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
|
||||
def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
"""Test IfElseNode with boolean conditions using various operators"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "and",
|
||||
"conditions": [condition.model_dump()],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_false_conditions():
|
||||
"""Test IfElseNode with boolean conditions that should evaluate to false"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean False Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
# Test boolean "is" operator (should be false)
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
|
||||
# Test boolean "=" operator (should be false)
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not contains" operator (should be false)
|
||||
{
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "bool_array"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data,
|
||||
},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_cases_structure():
|
||||
"""Test IfElseNode with boolean conditions using the new cases structure"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Cases Test",
|
||||
"type": "if-else",
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": "true",
|
||||
},
|
||||
{
|
||||
"comparison_operator": "is not",
|
||||
"variable_selector": ["start", "bool_false"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
assert result.outputs["selected_case_id"] == "true"
|
||||
|
||||
@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
|
||||
FilterCondition,
|
||||
Limit,
|
||||
ListOperatorNodeData,
|
||||
OrderBy,
|
||||
Order,
|
||||
OrderByConfig,
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
@ -27,7 +28,7 @@ def list_operator_node():
|
||||
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
|
||||
],
|
||||
),
|
||||
"order_by": OrderBy(enabled=False, value="asc"),
|
||||
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
|
||||
"limit": Limit(enabled=False, size=0),
|
||||
"extract_by": ExtractConfig(enabled=False, serial="1"),
|
||||
"title": "Test Title",
|
||||
|
||||
@ -59,7 +59,7 @@ def mock_response_receiver(monkeypatch) -> mock.Mock:
|
||||
@pytest.fixture
|
||||
def mock_logger(monkeypatch) -> logging.Logger:
|
||||
_logger = mock.MagicMock(spec=logging.Logger)
|
||||
monkeypatch.setattr(ext_request_logging, "_logger", _logger)
|
||||
monkeypatch.setattr(ext_request_logging, "logger", _logger)
|
||||
return _logger
|
||||
|
||||
|
||||
|
||||
@ -24,16 +24,18 @@ from core.variables.segments import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from factories import variable_factory
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
@ -139,6 +141,26 @@ def test_array_number_variable():
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_build_segment_scalar_values():
|
||||
@dataclass
|
||||
class TestCase:
|
||||
value: Any
|
||||
expected: Segment
|
||||
description: str
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
value=True,
|
||||
expected=BooleanSegment(value=True),
|
||||
description="build_segment with boolean should yield BooleanSegment",
|
||||
)
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
seg = build_segment(c.value)
|
||||
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
|
||||
f"but got: {error_message}"
|
||||
)
|
||||
|
||||
def test_build_segment_boolean_type_note(self):
|
||||
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
|
||||
# Boolean values in Python are subclasses of int, so they get processed as integers
|
||||
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
|
||||
def test_build_segment_boolean_type(self):
|
||||
"""Test that Boolean values are correctly handled as boolean type, not integers."""
|
||||
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
|
||||
# This is because the bool check now comes before the int check in build_segment
|
||||
true_segment = variable_factory.build_segment(True)
|
||||
false_segment = variable_factory.build_segment(False)
|
||||
|
||||
# Verify they are processed as integers, not as errors
|
||||
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
||||
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
||||
assert true_segment.value_type == SegmentType.INTEGER
|
||||
assert false_segment.value_type == SegmentType.INTEGER
|
||||
# Verify they are processed as booleans, not integers
|
||||
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
|
||||
assert false_segment.value is False, (
|
||||
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
|
||||
)
|
||||
assert true_segment.value_type == SegmentType.BOOLEAN
|
||||
assert false_segment.value_type == SegmentType.BOOLEAN
|
||||
|
||||
# Test array of booleans
|
||||
bool_array_segment = variable_factory.build_segment([True, False, True])
|
||||
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
|
||||
assert bool_array_segment.value == [True, False, True]
|
||||
|
||||
@ -83,7 +83,7 @@ class TestDatasetPermissionService:
|
||||
@pytest.fixture
|
||||
def mock_logging_dependencies(self):
|
||||
"""Mock setup for logging tests."""
|
||||
with patch("services.dataset_service.logging") as mock_logging:
|
||||
with patch("services.dataset_service.logger") as mock_logging:
|
||||
yield {
|
||||
"logging": mock_logging,
|
||||
}
|
||||
|
||||
@ -179,7 +179,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
delete_draft_variables_batch(app_id, 0)
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db):
|
||||
"""Test that batch deletion logs progress correctly."""
|
||||
app_id = "test-app-id"
|
||||
|
||||
Reference in New Issue
Block a user