feat: fix i18n missing keys and merge upstream/main (#24615)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: GuanMu <ballmanjq@gmail.com>
Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com>
Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: Qiang Lee <18018968632@163.com>
Co-authored-by: 李强04 <liqiang04@gaotu.cn>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Matri Qi <matrixdom@126.com>
Co-authored-by: huayaoyue6 <huayaoyue@163.com>
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
Co-authored-by: znn <jubinkumarsoni@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Muke Wang <shaodwaaron@gmail.com>
Co-authored-by: wangmuke <wangmuke@kingsware.cn>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Eric Guo <eric.guocz@gmail.com>
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: jiangbo721 <jiangbo721@163.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: hjlarry <25834719+hjlarry@users.noreply.github.com>
Co-authored-by: lxsummer <35754229+lxjustdoit@users.noreply.github.com>
Co-authored-by: 湛露先生 <zhanluxianshen@163.com>
Co-authored-by: Guangdong Liu <liugddx@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Yessenia-d <yessenia.contact@gmail.com>
Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com>
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: 17hz <0x149527@gmail.com>
Co-authored-by: Amy <1530140574@qq.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Nite Knite <nkCoding@gmail.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Co-authored-by: Petrus Han <petrus.hanks@gmail.com>
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
Co-authored-by: Kalo Chin <frog.beepers.0n@icloud.com>
Co-authored-by: Ujjwal Maurya <ujjwalsbx@gmail.com>
Co-authored-by: Maries <xh001x@hotmail.com>
This commit is contained in:
lyzno1
2025-08-27 15:07:28 +08:00
committed by GitHub
parent a63d1e87b1
commit 5bbf685035
625 changed files with 23778 additions and 10693 deletions

View File

@ -90,6 +90,7 @@ def test_flask_configs(monkeypatch):
"pool_recycle": 3600,
"pool_size": 30,
"pool_use_lifo": False,
"pool_reset_on_return": None,
}
assert config["CONSOLE_WEB_URL"] == "https://example.com"

View File

@ -1,9 +1,8 @@
import datetime
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple
from flask_restful import marshal
from flask_restx import marshal
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
@ -13,6 +12,7 @@ from controllers.console.app.workflow_draft_variable import (
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from libs.datetime_utils import naive_utc_now
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList
@ -57,7 +57,7 @@ class TestWorkflowDraftVariableFields:
)
sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
sys_var.last_edited_at = naive_utc_now()
sys_var.visible = True
expected_without_value = OrderedDict(
@ -88,7 +88,7 @@ class TestWorkflowDraftVariableFields:
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
node_var.last_edited_at = naive_utc_now()
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{

View File

@ -0,0 +1,134 @@
"""Test authentication security to prevent user enumeration."""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
import services.errors.account
from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
from controllers.console.error import AccountNotFound
class TestAuthenticationSecurity:
"""Test authentication endpoints for security against user enumeration."""
def setup_method(self):
"""Set up test fixtures."""
self.app = Flask(__name__)
self.api = Api(self.app)
self.api.add_resource(LoginApi, "/login")
self.client = self.app.test_client()
self.app.config["TESTING"] = True
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email sends reset password email when registration is allowed."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
mock_send_email.return_value = "token123"
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
result = login_api.post()
# Assert
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_wrong_password_returns_error(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
):
"""Test that wrong password returns AuthenticationFailedError."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
# Assert
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("existing@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email raises AccountNotFound when registration is disabled."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
# Assert
with pytest.raises(AccountNotFound) as exc_info:
login_api.post()
assert exc_info.value.error_code == "account_not_found"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")
mock_send_email.return_value = "token123"
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
from controllers.console.auth.login import ResetPasswordSendEmailApi
api = ResetPasswordSendEmailApi()
result = api.post()
assert result == {"result": "success", "data": "token123"}

View File

@ -0,0 +1,148 @@
"""Tests for LLMUsage entity."""
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
class TestLLMUsage:
"""Test cases for LLMUsage class."""
def test_from_metadata_with_all_tokens(self):
"""Test from_metadata when all token types are provided."""
metadata: LLMUsageMetadata = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_unit_price": 0.001,
"completion_unit_price": 0.002,
"total_price": 0.2,
"currency": "USD",
"latency": 1.5,
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
assert usage.prompt_unit_price == Decimal("0.001")
assert usage.completion_unit_price == Decimal("0.002")
assert usage.total_price == Decimal("0.2")
assert usage.currency == "USD"
assert usage.latency == 1.5
def test_from_metadata_with_prompt_tokens_only(self):
"""Test from_metadata when only prompt_tokens is provided."""
metadata: LLMUsageMetadata = {
"prompt_tokens": 100,
"total_tokens": 100,
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 0
assert usage.total_tokens == 100
def test_from_metadata_with_completion_tokens_only(self):
"""Test from_metadata when only completion_tokens is provided."""
metadata: LLMUsageMetadata = {
"completion_tokens": 50,
"total_tokens": 50,
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 0
assert usage.completion_tokens == 50
assert usage.total_tokens == 50
def test_from_metadata_calculates_total_when_missing(self):
"""Test from_metadata calculates total_tokens when not provided."""
metadata: LLMUsageMetadata = {
"prompt_tokens": 100,
"completion_tokens": 50,
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150 # Should be calculated
def test_from_metadata_with_total_but_no_completion(self):
"""
Test from_metadata when total_tokens is provided but completion_tokens is 0.
This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
"""
metadata: LLMUsageMetadata = {
"prompt_tokens": 479,
"completion_tokens": 0,
"total_tokens": 521,
}
usage = LLMUsage.from_metadata(metadata)
# This is the key fix - prompt tokens should remain as prompt tokens
assert usage.prompt_tokens == 479
assert usage.completion_tokens == 0
assert usage.total_tokens == 521
def test_from_metadata_with_empty_metadata(self):
"""Test from_metadata with empty metadata."""
metadata: LLMUsageMetadata = {}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 0
assert usage.completion_tokens == 0
assert usage.total_tokens == 0
assert usage.currency == "USD"
assert usage.latency == 0.0
def test_from_metadata_preserves_zero_completion_tokens(self):
"""
Test that zero completion_tokens are preserved when explicitly set.
This is important for agent nodes that only use prompt tokens.
"""
metadata: LLMUsageMetadata = {
"prompt_tokens": 1000,
"completion_tokens": 0,
"total_tokens": 1000,
"prompt_unit_price": 0.15,
"completion_unit_price": 0.60,
"prompt_price": 0.00015,
"completion_price": 0,
"total_price": 0.00015,
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_tokens == 1000
assert usage.completion_tokens == 0
assert usage.total_tokens == 1000
assert usage.prompt_price == Decimal("0.00015")
assert usage.completion_price == Decimal(0)
assert usage.total_price == Decimal("0.00015")
def test_from_metadata_with_decimal_values(self):
"""Test from_metadata handles decimal values correctly."""
metadata: LLMUsageMetadata = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_unit_price": "0.001",
"completion_unit_price": "0.002",
"prompt_price": "0.1",
"completion_price": "0.1",
"total_price": "0.2",
}
usage = LLMUsage.from_metadata(metadata)
assert usage.prompt_unit_price == Decimal("0.001")
assert usage.completion_unit_price == Decimal("0.002")
assert usage.prompt_price == Decimal("0.1")
assert usage.completion_price == Decimal("0.1")
assert usage.total_price == Decimal("0.2")

View File

@ -0,0 +1,460 @@
from collections.abc import Generator
import pytest
from core.agent.entities import AgentInvokeMessage
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage
class TestChunkMerger:
def test_file_chunk_initialization(self):
"""Test FileChunk initialization."""
chunk = FileChunk(1024)
assert chunk.bytes_written == 0
assert chunk.total_length == 1024
assert len(chunk.data) == 1024
def test_merge_blob_chunks_with_single_complete_chunk(self):
"""Test merging a single complete blob chunk."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# First chunk (partial)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
),
)
# Second chunk (final)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=1, total_length=10, blob=b"World", end=True
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
# The buffer should contain the complete data
assert result[0].message.blob[:10] == b"HelloWorld"
def test_merge_blob_chunks_with_multiple_files(self):
"""Test merging chunks from multiple files."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# File 1, chunk 1
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=4, blob=b"AB", end=False
),
)
# File 2, chunk 1
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file2", sequence=0, total_length=4, blob=b"12", end=False
),
)
# File 1, chunk 2 (final)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=1, total_length=4, blob=b"CD", end=True
),
)
# File 2, chunk 2 (final)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file2", sequence=1, total_length=4, blob=b"34", end=True
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 2
# Check that both files are properly merged
assert all(r.type == ToolInvokeMessage.MessageType.BLOB for r in result)
def test_merge_blob_chunks_passes_through_non_blob_messages(self):
"""Test that non-blob messages pass through unchanged."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Text message
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text="Hello"),
)
# Blob chunk
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=5, blob=b"Test", end=True
),
)
# Another text message
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text="World"),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 3
assert result[0].type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(result[0].message, ToolInvokeMessage.TextMessage)
assert result[0].message.text == "Hello"
assert result[1].type == ToolInvokeMessage.MessageType.BLOB
assert result[2].type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(result[2].message, ToolInvokeMessage.TextMessage)
assert result[2].message.text == "World"
def test_merge_blob_chunks_file_too_large(self):
"""Test that error is raised when file exceeds max size."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Send a chunk that would exceed the limit
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=100, blob=b"x" * 1024, end=False
),
)
with pytest.raises(ValueError) as exc_info:
list(merge_blob_chunks(mock_generator(), max_file_size=1000))
assert "File is too large" in str(exc_info.value)
def test_merge_blob_chunks_chunk_too_large(self):
"""Test that error is raised when chunk exceeds max chunk size."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Send a chunk that exceeds the max chunk size
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=10000, blob=b"x" * 9000, end=False
),
)
with pytest.raises(ValueError) as exc_info:
list(merge_blob_chunks(mock_generator(), max_chunk_size=8192))
assert "File chunk is too large" in str(exc_info.value)
def test_merge_blob_chunks_with_agent_invoke_message(self):
"""Test that merge_blob_chunks works with AgentInvokeMessage."""
def mock_generator() -> Generator[AgentInvokeMessage, None, None]:
# First chunk
yield AgentInvokeMessage(
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
message=AgentInvokeMessage.BlobChunkMessage(
id="agent_file", sequence=0, total_length=8, blob=b"Agent", end=False
),
)
# Final chunk
yield AgentInvokeMessage(
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
message=AgentInvokeMessage.BlobChunkMessage(
id="agent_file", sequence=1, total_length=8, blob=b"Data", end=True
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert isinstance(result[0], AgentInvokeMessage)
assert result[0].type == AgentInvokeMessage.MessageType.BLOB
def test_merge_blob_chunks_preserves_meta(self):
"""Test that meta information is preserved in merged messages."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=4, blob=b"Test", end=True
),
meta={"key": "value"},
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert result[0].meta == {"key": "value"}
def test_merge_blob_chunks_custom_limits(self):
"""Test merge_blob_chunks with custom size limits."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# This should work with custom limits
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
),
)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=1, total_length=500, blob=b"y" * 100, end=True
),
)
# Should work with custom limits
result = list(merge_blob_chunks(mock_generator(), max_file_size=1000, max_chunk_size=500))
assert len(result) == 1
# Should fail with smaller file size limit
def mock_generator2() -> Generator[ToolInvokeMessage, None, None]:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
),
)
with pytest.raises(ValueError):
list(merge_blob_chunks(mock_generator2(), max_file_size=300))
def test_merge_blob_chunks_data_integrity(self):
"""Test that merged chunks exactly match the original data."""
# Create original data
original_data = b"This is a test message that will be split into chunks for testing purposes."
chunk_size = 20
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Split original data into chunks
chunks = []
for i in range(0, len(original_data), chunk_size):
chunk_data = original_data[i : i + chunk_size]
is_last = (i + chunk_size) >= len(original_data)
chunks.append((i // chunk_size, chunk_data, is_last))
# Yield chunks
for sequence, data, is_end in chunks:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="test_file",
sequence=sequence,
total_length=len(original_data),
blob=data,
end=is_end,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
# Verify the merged data exactly matches the original
assert result[0].message.blob == original_data
def test_merge_blob_chunks_empty_chunk(self):
"""Test handling of empty chunks."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# First chunk with data
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
),
)
# Empty chunk in the middle
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=1, total_length=10, blob=b"", end=False
),
)
# Final chunk with data
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="file1", sequence=2, total_length=10, blob=b"World", end=True
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
# The final blob should contain "Hello" followed by "World"
assert result[0].message.blob[:10] == b"HelloWorld"
def test_merge_blob_chunks_single_chunk_file(self):
"""Test file that arrives as a single complete chunk."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Single chunk that is both first and last
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="single_chunk_file",
sequence=0,
total_length=11,
blob=b"Single Data",
end=True,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
assert result[0].message.blob == b"Single Data"
def test_merge_blob_chunks_concurrent_files(self):
"""Test that chunks from different files are properly separated."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Interleave chunks from three different files
files_data = {
"file1": b"First file content",
"file2": b"Second file data",
"file3": b"Third file",
}
# First chunk from each file
for file_id, data in files_data.items():
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id=file_id,
sequence=0,
total_length=len(data),
blob=data[:6],
end=False,
),
)
# Second chunk from each file (final)
for file_id, data in files_data.items():
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id=file_id,
sequence=1,
total_length=len(data),
blob=data[6:],
end=True,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 3
# Extract the blob data from results
blobs = set()
for r in result:
assert isinstance(r.message, ToolInvokeMessage.BlobMessage)
blobs.add(r.message.blob)
expected = {b"First file content", b"Second file data", b"Third file"}
assert blobs == expected
def test_merge_blob_chunks_exact_buffer_size(self):
"""Test that data fitting exactly in buffer works correctly."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Create data that exactly fills the declared buffer
exact_data = b"X" * 100
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="exact_file",
sequence=0,
total_length=100,
blob=exact_data[:50],
end=False,
),
)
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="exact_file",
sequence=1,
total_length=100,
blob=exact_data[50:],
end=True,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
assert len(result[0].message.blob) == 100
assert result[0].message.blob == b"X" * 100
def test_merge_blob_chunks_large_file_simulation(self):
"""Test handling of a large file split into many chunks."""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Simulate a 1MB file split into 128 chunks of 8KB each
chunk_size = 8192
num_chunks = 128
total_size = chunk_size * num_chunks
for i in range(num_chunks):
# Create unique data for each chunk to verify ordering
chunk_data = bytes([i % 256]) * chunk_size
is_last = i == num_chunks - 1
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="large_file",
sequence=i,
total_length=total_size,
blob=chunk_data,
end=is_last,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
assert len(result[0].message.blob) == 1024 * 1024
# Verify the data pattern is correct
merged_data = result[0].message.blob
chunk_size = 8192
num_chunks = 128
for i in range(num_chunks):
chunk_start = i * chunk_size
chunk_end = chunk_start + chunk_size
expected_byte = i % 256
chunk = merged_data[chunk_start:chunk_end]
assert all(b == expected_byte for b in chunk), f"Chunk {i} has incorrect data"
def test_merge_blob_chunks_sequential_order_required(self):
"""
Test note: The current implementation assumes chunks arrive in sequential order.
Out-of-order chunks would need additional logic to handle properly.
This test documents the expected behavior with sequential chunks.
"""
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
# Chunks arriving in correct sequential order
data_parts = [b"First", b"Second", b"Third"]
total_length = sum(len(part) for part in data_parts)
for i, part in enumerate(data_parts):
is_last = i == len(data_parts) - 1
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
message=ToolInvokeMessage.BlobChunkMessage(
id="ordered_file",
sequence=i,
total_length=total_length,
blob=part,
end=is_last,
),
)
result = list(merge_blob_chunks(mock_generator()))
assert len(result) == 1
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
assert result[0].message.blob == b"FirstSecondThird"

View File

@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
@ -13,6 +12,7 @@ import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@ -56,7 +56,7 @@ def sample_workflow_execution():
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
@ -199,7 +199,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
exec2 = WorkflowExecution.new(
id_=str(uuid4()),
@ -208,7 +208,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input2": "value2"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Save both executions
@ -235,7 +235,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
repo.save(execution)

View File

@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
@ -18,6 +17,7 @@ from core.workflow.entities.workflow_node_execution import (
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@ -65,7 +65,7 @@ def sample_workflow_node_execution():
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
@ -263,7 +263,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 1",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
@ -276,7 +276,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 2",
inputs={"input2": "value2"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
# Save both executions
@ -314,7 +314,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 2",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
@ -327,7 +327,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 1",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
# Save in random order

View File

@ -2,19 +2,19 @@
Unit tests for the RepositoryFactory.
This module tests the factory pattern implementation for creating repository instances
based on configuration, including error handling and validation.
based on configuration, including error handling.
"""
from unittest.mock import MagicMock, patch
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@ -23,98 +23,30 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom
class TestRepositoryFactory:
"""Test cases for RepositoryFactory."""
def test_import_class_success(self):
def test_import_string_success(self):
"""Test successful class import."""
# Test importing a real class
class_path = "unittest.mock.MagicMock"
result = DifyCoreRepositoryFactory._import_class(class_path)
result = import_string(class_path)
assert result is MagicMock
def test_import_class_invalid_path(self):
def test_import_string_invalid_path(self):
"""Test import with invalid module path."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalid.module.path")
assert "Cannot import repository class" in str(exc_info.value)
with pytest.raises(ImportError) as exc_info:
import_string("invalid.module.path")
assert "No module named" in str(exc_info.value)
def test_import_class_invalid_class_name(self):
def test_import_string_invalid_class_name(self):
"""Test import with invalid class name."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
assert "Cannot import repository class" in str(exc_info.value)
with pytest.raises(ImportError) as exc_info:
import_string("unittest.mock.NonExistentClass")
assert "does not define" in str(exc_info.value)
def test_import_class_malformed_path(self):
def test_import_string_malformed_path(self):
"""Test import with malformed path (no dots)."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalidpath")
assert "Cannot import repository class" in str(exc_info.value)
def test_validate_repository_interface_success(self):
"""Test successful interface validation."""
# Create a mock class that implements the required methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
# Create a mock interface class
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
# Should not raise an exception when all methods are present
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that's missing required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface that requires both methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def missing_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test that private methods are ignored during interface validation."""
class MockRepository:
def save(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def _private_method(self):
pass
# Should not raise exception - private methods should be ignored
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
with pytest.raises(ImportError) as exc_info:
import_string("invalidpath")
assert "doesn't look like a module path" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config):
@ -133,11 +65,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@ -170,34 +99,7 @@ class TestRepositoryFactory:
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Cannot import repository class" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
@ -212,11 +114,8 @@ class TestRepositoryFactory:
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
# Mock import_string to return a failing class
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@ -243,11 +142,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@ -280,34 +176,7 @@ class TestRepositoryFactory:
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Cannot import repository class" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Interface validation failed" in str(exc_info.value)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
@ -322,11 +191,8 @@ class TestRepositoryFactory:
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
# Mock import_string to return a failing class
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@ -359,11 +225,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,

View File

@ -0,0 +1,308 @@
from unittest.mock import Mock, patch
import pytest
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
from core.entities.provider_entities import (
CustomConfiguration,
ModelSettings,
ProviderQuotaType,
QuotaConfiguration,
QuotaUnit,
RestrictModel,
SystemConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from models.provider import Provider, ProviderType
@pytest.fixture
def mock_provider_entity():
"""Mock provider entity with basic configuration"""
provider_entity = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
background="background.png",
help=None,
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
provider_credential_schema=None,
model_credential_schema=None,
)
return provider_entity
@pytest.fixture
def mock_system_configuration():
"""Mock system configuration"""
quota_config = QuotaConfiguration(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=1000,
quota_used=0,
is_valid=True,
restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
)
system_config = SystemConfiguration(
enabled=True,
credentials={"openai_api_key": "test_key"},
quota_configurations=[quota_config],
current_quota_type=ProviderQuotaType.TRIAL,
)
return system_config
@pytest.fixture
def mock_custom_configuration():
"""Mock custom configuration"""
custom_config = CustomConfiguration(provider=None, models=[])
return custom_config
@pytest.fixture
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
"""Create a test provider configuration instance"""
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
return ProviderConfiguration(
tenant_id="test_tenant",
provider=mock_provider_entity,
preferred_provider_type=ProviderType.SYSTEM,
using_provider_type=ProviderType.SYSTEM,
system_configuration=mock_system_configuration,
custom_configuration=mock_custom_configuration,
model_settings=[],
)
class TestProviderConfiguration:
"""Test cases for ProviderConfiguration class"""
def test_get_current_credentials_system_provider_success(self, provider_configuration):
"""Test successfully getting credentials from system provider"""
# Arrange
provider_configuration.using_provider_type = ProviderType.SYSTEM
# Act
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
# Assert
assert credentials == {"openai_api_key": "test_key"}
def test_get_current_credentials_model_disabled(self, provider_configuration):
"""Test getting credentials when model is disabled"""
# Arrange
model_setting = ModelSettings(
model="gpt-4",
model_type=ModelType.LLM,
enabled=False,
load_balancing_configs=[],
has_invalid_load_balancing_configs=False,
)
provider_configuration.model_settings = [model_setting]
# Act & Assert
with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
"""Test getting credentials from custom provider with model configurations"""
# Arrange
provider_configuration.using_provider_type = ProviderType.CUSTOM
mock_model_config = Mock()
mock_model_config.model_type = ModelType.LLM
mock_model_config.model = "gpt-4"
mock_model_config.credentials = {"openai_api_key": "custom_key"}
provider_configuration.custom_configuration.models = [mock_model_config]
# Act
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
# Assert
assert credentials == {"openai_api_key": "custom_key"}
def test_get_system_configuration_status_active(self, provider_configuration):
"""Test getting active system configuration status"""
# Arrange
provider_configuration.system_configuration.enabled = True
# Act
status = provider_configuration.get_system_configuration_status()
# Assert
assert status == SystemConfigurationStatus.ACTIVE
def test_get_system_configuration_status_unsupported(self, provider_configuration):
"""Test getting unsupported system configuration status"""
# Arrange
provider_configuration.system_configuration.enabled = False
# Act
status = provider_configuration.get_system_configuration_status()
# Assert
assert status == SystemConfigurationStatus.UNSUPPORTED
def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
"""Test getting quota exceeded system configuration status"""
# Arrange
provider_configuration.system_configuration.enabled = True
quota_config = provider_configuration.system_configuration.quota_configurations[0]
quota_config.is_valid = False
# Act
status = provider_configuration.get_system_configuration_status()
# Assert
assert status == SystemConfigurationStatus.QUOTA_EXCEEDED
def test_is_custom_configuration_available_with_provider(self, provider_configuration):
"""Test custom configuration availability with provider credentials"""
# Arrange
mock_provider = Mock()
mock_provider.available_credentials = ["openai_api_key"]
provider_configuration.custom_configuration.provider = mock_provider
provider_configuration.custom_configuration.models = []
# Act
result = provider_configuration.is_custom_configuration_available()
# Assert
assert result is True
def test_is_custom_configuration_available_with_models(self, provider_configuration):
"""Test custom configuration availability with model configurations"""
# Arrange
provider_configuration.custom_configuration.provider = None
provider_configuration.custom_configuration.models = [Mock()]
# Act
result = provider_configuration.is_custom_configuration_available()
# Assert
assert result is True
def test_is_custom_configuration_available_false(self, provider_configuration):
"""Test custom configuration not available"""
# Arrange
provider_configuration.custom_configuration.provider = None
provider_configuration.custom_configuration.models = []
# Act
result = provider_configuration.is_custom_configuration_available()
# Assert
assert result is False
@patch("core.entities.provider_configuration.Session")
def test_get_provider_record_found(self, mock_session, provider_configuration):
"""Test getting provider record successfully"""
# Arrange
mock_provider = Mock(spec=Provider)
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider
# Act
result = provider_configuration._get_provider_record(mock_session_instance)
# Assert
assert result == mock_provider
@patch("core.entities.provider_configuration.Session")
def test_get_provider_record_not_found(self, mock_session, provider_configuration):
"""Test getting provider record when not found"""
# Arrange
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
# Act
result = provider_configuration._get_provider_record(mock_session_instance)
# Assert
assert result is None
def test_init_with_customizable_model_only(
self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
):
"""Test initialization with customizable model only configuration"""
# Arrange
mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]
# Act
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
config = ProviderConfiguration(
tenant_id="test_tenant",
provider=mock_provider_entity,
preferred_provider_type=ProviderType.SYSTEM,
using_provider_type=ProviderType.SYSTEM,
system_configuration=mock_system_configuration,
custom_configuration=mock_custom_configuration,
model_settings=[],
)
# Assert
assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods
def test_get_current_credentials_with_restricted_models(self, provider_configuration):
"""Test getting credentials with model restrictions"""
# Arrange
provider_configuration.using_provider_type = ProviderType.SYSTEM
# Act
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")
# Assert
assert credentials is not None
assert "openai_api_key" in credentials
@patch("core.entities.provider_configuration.Session")
def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
"""Test getting specific provider credential successfully"""
# Arrange
credential_id = "test_credential_id"
mock_credential = Mock()
mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}'
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential
# Act
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
mock_get.return_value = {"openai_api_key": "test_key"}
result = provider_configuration._get_specific_provider_credential(credential_id)
# Assert
assert result == {"openai_api_key": "test_key"}
@patch("core.entities.provider_configuration.Session")
def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
"""Test getting specific provider credential when not found"""
# Arrange
credential_id = "nonexistent_credential_id"
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
# Act & Assert
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
mock_get.return_value = None
result = provider_configuration._get_specific_provider_credential(credential_id)
assert result is None
# Act
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
# Assert
assert credentials == {"openai_api_key": "test_key"}

View File

@ -1,190 +1,185 @@
# from core.entities.provider_entities import ModelSettings
# from core.model_runtime.entities.model_entities import ModelType
# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
# from core.provider_manager import ProviderManager
# from models.provider import LoadBalancingModelConfig, ProviderModelSetting
import pytest
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
# def test__to_model_settings(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
@pytest.fixture
def mock_provider_entity(mocker):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=True,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# ),
# LoadBalancingModelConfig(
# id="id2",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="first",
# encrypted_config='{"openai_api_key": "fake_key"}',
# enabled=True,
# ),
# ]
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
# )
# provider_manager = ProviderManager()
# # Running the method
# result = provider_manager._to_model_settings(provider_entity,
# provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 2
# assert result[0].load_balancing_configs[0].name == "__inherit__"
# assert result[0].load_balancing_configs[1].name == "first"
return mock_entity
# def test__to_model_settings_only_one_lb(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
def test__to_model_settings(mocker, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=True,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True,
),
LoadBalancingModelConfig(
id="id2",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
),
]
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=True,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# )
# ]
provider_manager = ProviderManager()
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
# )
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# provider_manager = ProviderManager()
# # Running the method
# result = provider_manager._to_model_settings(
# provider_entity, provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 0
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 2
assert result[0].load_balancing_configs[0].name == "__inherit__"
assert result[0].load_balancing_configs[1].name == "first"
# def test__to_model_settings_lb_disabled(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=True,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True,
)
]
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=False,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# ),
# LoadBalancingModelConfig(
# id="id2",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="first",
# encrypted_config='{"openai_api_key": "fake_key"}',
# enabled=True,
# ),
# ]
provider_manager = ProviderManager()
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get",
# return_value={"openai_api_key": "fake_key"}
# )
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# provider_manager = ProviderManager()
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0
# # Running the method
# result = provider_manager._to_model_settings(provider_entity,
# provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=False,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True,
),
LoadBalancingModelConfig(
id="id2",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
),
]
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0

View File

@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
expected_non_array_types = [
SegmentType.INTEGER,
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
SegmentType.BOOLEAN,
]
for seg_type in expected_array_types:

View File

@ -0,0 +1,729 @@
"""
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
This module provides thorough testing of the validation logic for all SegmentType values,
including edge cases, error conditions, and different ArrayValidation strategies.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.types import ArrayValidation, SegmentType
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
@dataclass
class ValidationTestCase:
"""Test case data structure for validation tests."""
segment_type: SegmentType
value: Any
expected: bool
description: str
def get_id(self):
return self.description
@dataclass
class ArrayValidationTestCase:
"""Test case data structure for array validation tests."""
segment_type: SegmentType
value: Any
array_validation: ArrayValidation
expected: bool
description: str
def get_id(self):
return self.description
# Test data construction functions
def get_boolean_cases() -> list[ValidationTestCase]:
return [
# valid values
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
# Invalid values
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
]
def get_number_cases() -> list[ValidationTestCase]:
"""Get test cases for valid number values."""
return [
# valid values
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
# invalid number values
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
]
def get_string_cases() -> list[ValidationTestCase]:
"""Get test cases for valid string values."""
return [
# valid values
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
# invalid values
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
]
def get_object_cases() -> list[ValidationTestCase]:
"""Get test cases for valid object values."""
return [
# valid cases
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
# invalid cases
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
]
def get_secret_cases() -> list[ValidationTestCase]:
"""Get test cases for valid secret values."""
return [
# valid cases
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
# invalid cases
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
]
def get_file_cases() -> list[ValidationTestCase]:
"""Get test cases for valid file values."""
test_file = create_test_file()
image_file = create_test_file(
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
)
remote_file = create_test_file(
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
)
return [
# valid cases
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
# invalid cases
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
]
def get_none_cases() -> list[ValidationTestCase]:
"""Get test cases for valid none values."""
return [
# valid cases
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
# invalid cases
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
]
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_ANY validation."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.FIRST,
True,
"Mixed types with FIRST validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY,
[1, "string", 3.14, {"key": "value"}, True],
ArrayValidation.ALL,
True,
"Mixed types with ALL validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
),
]
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", "world"],
ArrayValidation.NONE,
True,
"Valid strings with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, 456],
ArrayValidation.NONE,
True,
"Invalid elements with NONE validation",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["valid", 123, True],
ArrayValidation.NONE,
True,
"Mixed types with NONE validation",
),
]
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
["hello", 123, True],
ArrayValidation.FIRST,
True,
"First valid, others invalid",
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING,
[123, "hello", "world"],
ArrayValidation.FIRST,
False,
"First invalid, others valid",
),
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
]
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
return [
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
),
ArrayValidationTestCase(
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
),
]
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
),
ArrayValidationTestCase(
SegmentType.ARRAY_NUMBER,
[float("inf"), float("-inf"), float("nan")],
ArrayValidation.ALL,
True,
"Special float values",
),
]
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "not an object"],
ArrayValidation.FIRST,
True,
"First valid object",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
["not an object", {"valid": "object"}],
ArrayValidation.FIRST,
False,
"First invalid",
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{}, {"a": 1}, {"nested": {"key": "value"}}],
ArrayValidation.ALL,
True,
"All valid objects",
),
ArrayValidationTestCase(
SegmentType.ARRAY_OBJECT,
[{"valid": "object"}, "invalid", {"another": "object"}],
ArrayValidation.ALL,
False,
"One invalid element",
),
]
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_FILE validation with different strategies."""
file1 = create_test_file(filename="file1.txt")
file2 = create_test_file(filename="file2.txt")
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
),
# ALL strategy
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
ArrayValidationTestCase(
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
),
]
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
return [
# NONE strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
),
# FIRST strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
),
# ALL strategy
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
),
ArrayValidationTestCase(
SegmentType.ARRAY_BOOLEAN,
[True, "false", False],
ArrayValidation.ALL,
False,
"One invalid element (string)",
),
]
class TestSegmentTypeIsValid:
"""Test suite for SegmentType.is_valid method covering all non-array types."""
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
def test_boolean_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
def test_number_validation(self, case: ValidationTestCase):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
def test_string_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
def test_object_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
def test_secret_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
def test_file_validation(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
def test_none_validation_valid_cases(self, case):
assert case.segment_type.is_valid(case.value) == case.expected
def test_unsupported_segment_type_raises_assertion_error(self):
"""Test that unsupported SegmentType values raise AssertionError."""
# GROUP is not handled in is_valid method
with pytest.raises(AssertionError, match="this statement should be unreachable"):
SegmentType.GROUP.is_valid("any value")
class TestSegmentTypeArrayValidation:
"""Test suite for SegmentType._validate_array method and array type validation."""
def test_array_validation_non_list_values(self):
"""Test that non-list values return False for all array types."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
non_list_values = [
"not a list",
123,
3.14,
True,
None,
{"key": "value"},
create_test_file(),
]
for array_type in array_types:
for value in non_list_values:
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
def test_empty_array_validation(self):
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
]
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
for array_type in array_types:
for strategy in validation_strategies:
assert array_type.is_valid([], strategy) is True, (
f"{array_type} should accept empty array with {strategy}"
)
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
def test_array_any_validation(self, case):
"""Test ARRAY_ANY validation accepts any list regardless of content."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_none_strategy(self, case):
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_first_strategy(self, case):
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
def test_array_string_validation_with_all_strategy(self, case):
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
def test_array_number_validation_with_different_strategies(self, case):
"""Test ARRAY_NUMBER validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
def test_array_object_validation_with_different_strategies(self, case):
"""Test ARRAY_OBJECT validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
def test_array_file_validation_with_different_strategies(self, case):
"""Test ARRAY_FILE validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
def test_array_boolean_validation_with_different_strategies(self, case):
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
def test_default_array_validation_strategy(self):
"""Test that default array validation strategy is FIRST."""
# When no array_validation parameter is provided, it should default to FIRST
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
def test_array_validation_edge_cases(self):
"""Test edge cases for array validation."""
# Test with nested arrays (should be invalid for specific array types)
nested_array = [["nested", "array"], ["another", "nested"]]
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
# Test with very large arrays (performance consideration)
large_valid_array = ["string"] * 1000
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
class TestSegmentTypeValidationIntegration:
"""Integration tests for SegmentType validation covering interactions between methods."""
def test_non_array_types_ignore_array_validation_parameter(self):
"""Test that non-array types ignore the array_validation parameter."""
non_array_types = [
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
]
for segment_type in non_array_types:
# Create appropriate valid value for each type
valid_value: Any
if segment_type == SegmentType.STRING:
valid_value = "test"
elif segment_type == SegmentType.NUMBER:
valid_value = 42
elif segment_type == SegmentType.BOOLEAN:
valid_value = True
elif segment_type == SegmentType.OBJECT:
valid_value = {"key": "value"}
elif segment_type == SegmentType.SECRET:
valid_value = "secret"
elif segment_type == SegmentType.FILE:
valid_value = create_test_file()
elif segment_type == SegmentType.NONE:
valid_value = None
else:
continue # Skip unsupported types
# All array validation strategies should give the same result
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
assert result_none == result_first == result_all == True, (
f"{segment_type} should ignore array_validation parameter"
)
def test_comprehensive_type_coverage(self):
"""Test that all SegmentType enum values are covered in validation tests."""
all_segment_types = set(SegmentType)
# Types that should be handled by is_valid method
handled_types = {
# Non-array types
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
# Array types
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
}
# Types that are not handled by is_valid (should raise AssertionError)
unhandled_types = {
SegmentType.GROUP,
SegmentType.INTEGER, # Handled by NUMBER validation logic
SegmentType.FLOAT, # Handled by NUMBER validation logic
}
# Verify all types are accounted for
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
# Test that handled types work correctly
for segment_type in handled_types:
if segment_type.is_array_type():
# Test with empty array (should always be valid)
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
else:
# Test with appropriate valid value
if segment_type == SegmentType.STRING:
assert segment_type.is_valid("test") is True
elif segment_type == SegmentType.NUMBER:
assert segment_type.is_valid(42) is True
elif segment_type == SegmentType.BOOLEAN:
assert segment_type.is_valid(True) is True
elif segment_type == SegmentType.OBJECT:
assert segment_type.is_valid({}) is True
elif segment_type == SegmentType.SECRET:
assert segment_type.is_valid("secret") is True
elif segment_type == SegmentType.FILE:
assert segment_type.is_valid(create_test_file()) is True
elif segment_type == SegmentType.NONE:
assert segment_type.is_valid(None) is True
def test_boolean_vs_integer_type_distinction(self):
"""Test the important distinction between boolean and integer types in validation."""
# This tests the comment in the code about bool being a subclass of int
# Boolean type should only accept actual booleans, not integers
assert SegmentType.BOOLEAN.is_valid(True) is True
assert SegmentType.BOOLEAN.is_valid(False) is True
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
assert SegmentType.NUMBER.is_valid(42) is True
assert SegmentType.NUMBER.is_valid(3.14) is True
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
def test_array_validation_recursive_behavior(self):
"""Test that array validation correctly handles recursive validation calls."""
# When validating array elements, _validate_array calls is_valid recursively
# with ArrayValidation.NONE to avoid infinite recursion
# Test nested validation doesn't cause issues
nested_arrays = [["inner", "array"], ["another", "inner"]]
# ARRAY_ANY should accept nested arrays
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
# ARRAY_STRING should reject nested arrays (first element is not a string)
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False

View File

@ -1,6 +1,5 @@
import uuid
from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
@ -15,6 +14,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now())
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
route_node_state.finished_at = naive_utc_now()
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,

View File

@ -0,0 +1,27 @@
from core.variables.types import SegmentType
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
class TestParameterConfig:
def test_select_type(self):
data = {
"name": "yes_or_no",
"type": "select",
"options": ["yes", "no"],
"description": "a simple select made of `yes` and `no`",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.STRING
assert pc.options == data["options"]
def test_validate_bool_type(self):
data = {
"name": "boolean",
"type": "bool",
"description": "a simple boolean parameter",
"required": True,
}
pc = ParameterConfig.model_validate(data)
assert pc.type == SegmentType.BOOLEAN

View File

@ -0,0 +1,567 @@
"""
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.model_runtime.entities import LLMMode
from core.variables.types import SegmentType
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.exc import (
InvalidNumberOfParametersError,
InvalidSelectValueError,
InvalidValueTypeError,
RequiredParameterMissingError,
)
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from factories.variable_factory import build_segment_with_type
@dataclass
class ValidTestCase:
"""Test case data for valid scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
def get_name(self) -> str:
return self.name
@dataclass
class ErrorTestCase:
"""Test case data for error scenarios."""
name: str
parameters: list[ParameterConfig]
result: dict[str, Any]
expected_exception: type[Exception]
expected_message: str
def get_name(self) -> str:
return self.name
@dataclass
class TransformTestCase:
"""Test case data for transformation scenarios."""
name: str
parameters: list[ParameterConfig]
input_result: dict[str, Any]
expected_result: dict[str, Any]
def get_name(self) -> str:
return self.name
class TestParameterExtractorNodeMethods:
"""Test helper class that provides access to the methods under test."""
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _validate_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._validate_result(data=data, result=result)
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
"""Wrapper to call _transform_result method."""
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
return node._transform_result(data=data, result=result)
class TestValidateResult:
"""Test cases for _validate_result method."""
@staticmethod
def get_valid_test_cases() -> list[ValidTestCase]:
"""Get test cases that should pass validation."""
return [
ValidTestCase(
name="single_string_parameter",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John"},
),
ValidTestCase(
name="single_number_parameter_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": 25},
),
ValidTestCase(
name="single_number_parameter_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
result={"price": 19.99},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_true",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": True},
),
ValidTestCase(
name="single_bool_parameter_false",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": False},
),
ValidTestCase(
name="select_parameter_valid_option",
parameters=[
ParameterConfig(
name="status",
type="select", # pyright: ignore[reportArgumentType]
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "active"},
),
ValidTestCase(
name="array_string_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", "tag2", "tag3"]},
),
ValidTestCase(
name="array_number_parameter",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, 92.5, 78]},
),
ValidTestCase(
name="array_object_parameter",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, {"name": "item2"}]},
),
ValidTestCase(
name="multiple_parameters",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
],
result={"name": "John", "age": 25, "active": True},
),
ValidTestCase(
name="optional_parameter_present",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
],
result={"name": "John", "nickname": "Johnny"},
),
ValidTestCase(
name="empty_array_parameter",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": []},
),
]
@staticmethod
def get_error_test_cases() -> list[ErrorTestCase]:
"""Get test cases that should raise exceptions."""
return [
ErrorTestCase(
name="invalid_number_of_parameters_too_few",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
],
result={"name": "John"},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_number_of_parameters_too_many",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
result={"name": "John", "age": 25},
expected_exception=InvalidNumberOfParametersError,
expected_message="Invalid number of parameters",
),
ErrorTestCase(
name="invalid_string_value_none",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
],
result={"name": None}, # Parameter present but None value, will trigger type check first
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
),
ErrorTestCase(
name="invalid_select_value",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
result={"status": "pending"},
expected_exception=InvalidSelectValueError,
expected_message="Invalid `select` value for parameter status",
),
ErrorTestCase(
name="invalid_number_value_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
result={"age": "twenty-five"},
expected_exception=InvalidValueTypeError,
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
),
ErrorTestCase(
name="invalid_bool_value_string",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
result={"active": "yes"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
),
),
ErrorTestCase(
name="invalid_string_value_number",
parameters=[
ParameterConfig(
name="description", type=SegmentType.STRING, description="Description", required=True
)
],
result={"description": 123},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
),
),
ErrorTestCase(
name="invalid_array_value_not_list",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": "tag1,tag2,tag3"},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
),
),
ErrorTestCase(
name="invalid_array_number_wrong_element_type",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
result={"scores": [85, "ninety-two", 78]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_string_wrong_element_type",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
result={"tags": ["tag1", 123, "tag3"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
),
),
ErrorTestCase(
name="invalid_array_object_wrong_element_type",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
result={"items": [{"name": "item1"}, "item2"]},
expected_exception=InvalidValueTypeError,
expected_message=(
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
),
),
ErrorTestCase(
name="required_parameter_missing",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
],
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
expected_exception=RequiredParameterMissingError,
expected_message="Parameter name is required",
),
]
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
def test_validate_result_valid_cases(self, test_case):
"""Test _validate_result with valid inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.validate_result(data=node_data, result=test_case.result)
assert result == test_case.result, f"Failed for case: {test_case.name}"
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
def test_validate_result_error_cases(self, test_case):
"""Test _validate_result with invalid inputs that should raise exceptions."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
with pytest.raises(test_case.expected_exception) as exc_info:
helper.validate_result(data=node_data, result=test_case.result)
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
class TestTransformResult:
"""Test cases for _transform_result method."""
@staticmethod
def get_transform_test_cases() -> list[TransformTestCase]:
"""Get test cases for result transformation."""
return [
# String parameter transformation
TransformTestCase(
name="string_parameter_present",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={"name": "John"},
expected_result={"name": "John"},
),
TransformTestCase(
name="string_parameter_missing",
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
input_result={},
expected_result={"name": ""},
),
# Number parameter transformation
TransformTestCase(
name="number_parameter_int_present",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": 25},
expected_result={"age": 25},
),
TransformTestCase(
name="number_parameter_float_present",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": 19.99},
expected_result={"price": 19.99},
),
TransformTestCase(
name="number_parameter_missing",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={},
expected_result={"age": 0},
),
# Bool parameter transformation
TransformTestCase(
name="bool_parameter_missing",
parameters=[
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
],
input_result={},
expected_result={"active": False},
),
# Select parameter transformation
TransformTestCase(
name="select_parameter_present",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={"status": "active"},
expected_result={"status": "active"},
),
TransformTestCase(
name="select_parameter_missing",
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
description="Status",
required=True,
options=["active", "inactive"],
)
],
input_result={},
expected_result={"status": ""},
),
# Array parameter transformation - present cases
TransformTestCase(
name="array_string_parameter_present",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={"tags": ["tag1", "tag2"]},
expected_result={
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
},
),
TransformTestCase(
name="array_number_parameter_present",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, 92.5]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
},
),
TransformTestCase(
name="array_number_parameter_with_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "92.5", "78"]},
expected_result={
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
},
),
TransformTestCase(
name="array_object_parameter_present",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
expected_result={
"items": build_segment_with_type(
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
)
},
),
# Array parameter transformation - missing cases
TransformTestCase(
name="array_string_parameter_missing",
parameters=[
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
],
input_result={},
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
),
TransformTestCase(
name="array_number_parameter_missing",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={},
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
),
TransformTestCase(
name="array_object_parameter_missing",
parameters=[
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
],
input_result={},
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
),
# Multiple parameters transformation
TransformTestCase(
name="multiple_parameters_mixed",
parameters=[
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
],
input_result={"name": "John", "age": 25},
expected_result={
"name": "John",
"age": 25,
"active": False,
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
},
),
# Number parameter transformation with string conversion
TransformTestCase(
name="number_parameter_string_to_float",
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
input_result={"price": "19.99"},
expected_result={"price": 19.99}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_string_to_int",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "25"},
expected_result={"age": 25}, # String not converted, falls back to default
),
TransformTestCase(
name="number_parameter_invalid_string",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": "invalid_number"},
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
),
TransformTestCase(
name="number_parameter_non_string_non_number",
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
expected_result={"age": 0}, # Falls back to default
),
TransformTestCase(
name="array_number_parameter_with_invalid_string_conversion",
parameters=[
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
],
input_result={"scores": [85, "invalid", "78"]},
expected_result={
"scores": build_segment_with_type(
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
) # Invalid string skipped
},
),
]
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
def test_transform_result_cases(self, test_case):
"""Test _transform_result with various inputs."""
helper = TestParameterExtractorNodeMethods()
node_data = ParameterExtractorNodeData(
title="Test Node",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
query=["test_query"],
parameters=test_case.parameters,
reasoning_mode="function_call",
vision=VisionConfig(),
)
result = helper.transform_result(data=node_data, result=test_case.input_result)
assert result == test_case.expected_result, (
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
)

View File

@ -2,6 +2,8 @@ import time
import uuid
from unittest.mock import MagicMock, Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def _get_test_conditions() -> list:
conditions = [
# Test boolean "is" operator
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
# Test boolean "is not" operator
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
# Test boolean "=" operator
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
# Test boolean "≠" operator
{"comparison_operator": "", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not null" operator
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
# Test boolean array "contains" operator
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
# Test boolean "in" operator
{
"comparison_operator": "in",
"variable_selector": ["start", "bool_true"],
"value": ["true", "false"],
},
]
return [Condition.model_validate(i) for i in conditions]
def _get_condition_test_id(c: Condition):
return c.comparison_operator
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
def test_execute_if_else_boolean_conditions(condition: Condition):
"""Test IfElseNode with boolean conditions using various operators"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
node_data = {
"title": "Boolean Test",
"type": "if-else",
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def test_execute_if_else_boolean_false_conditions():
"""Test IfElseNode with boolean conditions that should evaluate to false"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
pool.add(["start", "bool_array"], [True, False, True])
node_data = {
"title": "Boolean False Test",
"type": "if-else",
"logical_operator": "or",
"conditions": [
# Test boolean "is" operator (should be false)
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
# Test boolean "=" operator (should be false)
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
# Test boolean "not contains" operator (should be false)
{
"comparison_operator": "not contains",
"variable_selector": ["start", "bool_array"],
"value": "true",
},
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": node_data,
},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is False
def test_execute_if_else_boolean_cases_structure():
"""Test IfElseNode with boolean conditions using the new cases structure"""
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool with boolean values
pool = VariablePool(
system_variables=SystemVariable(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
pool.add(["start", "bool_false"], False)
node_data = {
"title": "Boolean Cases Test",
"type": "if-else",
"cases": [
{
"case_id": "true",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "is",
"variable_selector": ["start", "bool_true"],
"value": "true",
},
{
"comparison_operator": "is not",
"variable_selector": ["start", "bool_false"],
"value": "true",
},
],
}
],
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
assert result.outputs["selected_case_id"] == "true"

View File

@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
Order,
OrderByConfig,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
@ -27,7 +28,7 @@ def list_operator_node():
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",

View File

@ -1,5 +1,4 @@
import json
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
@ -23,6 +22,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import Workflow, WorkflowRun
@ -145,8 +145,8 @@ def real_workflow():
workflow.graph = json.dumps(graph_data)
workflow.features = json.dumps({"file_upload": {"enabled": False}})
workflow.created_by = "test-user-id"
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.created_at = naive_utc_now()
workflow.updated_at = naive_utc_now()
workflow._environment_variables = "{}"
workflow._conversation_variables = "{}"
@ -169,7 +169,7 @@ def real_workflow_run():
workflow_run.outputs = json.dumps({"answer": "test answer"})
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
workflow_run.created_by = "test-user-id"
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.created_at = naive_utc_now()
return workflow_run
@ -211,7 +211,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@ -245,7 +245,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@ -282,7 +282,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@ -335,7 +335,7 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@ -366,7 +366,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
event.start_at = datetime.now(UTC).replace(tzinfo=None)
event.start_at = naive_utc_now()
# Create a real node execution
@ -379,7 +379,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
# Pre-populate the cache with the node execution
@ -409,7 +409,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@ -443,7 +443,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
event.start_at = datetime.now(UTC).replace(tzinfo=None)
event.start_at = naive_utc_now()
event.error = "Test error message"
# Create a real node execution
@ -457,7 +457,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
created_at=datetime.now(UTC).replace(tzinfo=None),
created_at=naive_utc_now(),
)
# Pre-populate the cache with the node execution

View File

@ -59,7 +59,7 @@ def mock_response_receiver(monkeypatch) -> mock.Mock:
@pytest.fixture
def mock_logger(monkeypatch) -> logging.Logger:
_logger = mock.MagicMock(spec=logging.Logger)
monkeypatch.setattr(ext_request_logging, "_logger", _logger)
monkeypatch.setattr(ext_request_logging, "logger", _logger)
return _logger

View File

@ -24,16 +24,18 @@ from core.variables.segments import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from core.variables.types import SegmentType
from factories import variable_factory
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
def test_string_variable():
@ -139,6 +141,26 @@ def test_array_number_variable():
assert isinstance(variable.value[1], float)
def test_build_segment_scalar_values():
@dataclass
class TestCase:
value: Any
expected: Segment
description: str
cases = [
TestCase(
value=True,
expected=BooleanSegment(value=True),
description="build_segment with boolean should yield BooleanSegment",
)
]
for idx, c in enumerate(cases, 1):
seg = build_segment(c.value)
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
f"but got: {error_message}"
)
def test_build_segment_boolean_type_note(self):
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
# Boolean values in Python are subclasses of int, so they get processed as integers
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
def test_build_segment_boolean_type(self):
"""Test that Boolean values are correctly handled as boolean type, not integers."""
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
# This is because the bool check now comes before the int check in build_segment
true_segment = variable_factory.build_segment(True)
false_segment = variable_factory.build_segment(False)
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.INTEGER
# Verify they are processed as booleans, not integers
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
assert false_segment.value is False, (
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
)
assert true_segment.value_type == SegmentType.BOOLEAN
assert false_segment.value_type == SegmentType.BOOLEAN
# Test array of booleans
bool_array_segment = variable_factory.build_segment([True, False, True])
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
assert bool_array_segment.value == [True, False, True]

View File

@ -83,7 +83,7 @@ class TestDatasetPermissionService:
@pytest.fixture
def mock_logging_dependencies(self):
"""Mock setup for logging tests."""
with patch("services.dataset_service.logging") as mock_logging:
with patch("services.dataset_service.logger") as mock_logging:
yield {
"logging": mock_logging,
}

View File

@ -93,16 +93,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
with (
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.datetime") as mock_datetime,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
mock_naive_utc_now.return_value = current_time
yield {
"get_document": mock_get_doc,
"db_session": mock_db,
"datetime": mock_datetime,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
@ -120,21 +119,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert document.enabled == True
assert document.disabled_at is None
assert document.disabled_by is None
assert document.updated_at == current_time.replace(tzinfo=None)
assert document.updated_at == current_time
def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was disabled correctly."""
assert document.enabled == False
assert document.disabled_at == current_time.replace(tzinfo=None)
assert document.disabled_at == current_time
assert document.disabled_by == user_id
assert document.updated_at == current_time.replace(tzinfo=None)
assert document.updated_at == current_time
def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was archived correctly."""
assert document.archived == True
assert document.archived_at == current_time.replace(tzinfo=None)
assert document.archived_at == current_time
assert document.archived_by == user_id
assert document.updated_at == current_time.replace(tzinfo=None)
assert document.updated_at == current_time
def _assert_document_unarchived(self, document: Mock):
"""Helper method to verify document was unarchived correctly."""
@ -430,7 +429,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Verify document attributes were updated correctly
self._assert_document_unarchived(archived_doc)
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None)
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify Redis cache was set (because document is enabled)
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
@ -495,9 +494,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Verify document was unarchived
self._assert_document_unarchived(archived_disabled_doc)
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace(
tzinfo=None
)
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify no Redis cache was set (document is disabled)
redis_mock.setex.assert_not_called()

View File

@ -1,7 +1,7 @@
from unittest.mock import Mock, patch
import pytest
from flask_restful import reqparse
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs

View File

@ -1,7 +1,7 @@
from unittest.mock import Mock, patch
import pytest
from flask_restful import reqparse
from flask_restx import reqparse
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService

View File

@ -179,7 +179,7 @@ class TestDeleteDraftVariablesBatch:
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.logging")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"