mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 05:35:58 +08:00
feat: fix i18n missing keys and merge upstream/main (#24615)
Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com> Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: zhanluxianshen <zhanluxianshen@163.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: GuanMu <ballmanjq@gmail.com> Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: Qiang Lee <18018968632@163.com> Co-authored-by: 李强04 <liqiang04@gaotu.cn> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Matri Qi <matrixdom@126.com> Co-authored-by: huayaoyue6 <huayaoyue@163.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: znn <jubinkumarsoni@gmail.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Muke Wang <shaodwaaron@gmail.com> Co-authored-by: wangmuke <wangmuke@kingsware.cn> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Eric Guo <eric.guocz@gmail.com> Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: jiangbo721 <jiangbo721@163.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: hjlarry <25834719+hjlarry@users.noreply.github.com> Co-authored-by: lxsummer <35754229+lxjustdoit@users.noreply.github.com> Co-authored-by: 湛露先生 <zhanluxianshen@163.com> Co-authored-by: Guangdong Liu <liugddx@gmail.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yessenia-d <yessenia.contact@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: 17hz <0x149527@gmail.com> Co-authored-by: Amy <1530140574@qq.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Nite Knite <nkCoding@gmail.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Co-authored-by: Petrus Han <petrus.hanks@gmail.com> Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com> Co-authored-by: Kalo Chin <frog.beepers.0n@icloud.com> Co-authored-by: Ujjwal Maurya <ujjwalsbx@gmail.com> Co-authored-by: Maries <xh001x@hotmail.com>
This commit is contained in:
@ -90,6 +90,7 @@ def test_flask_configs(monkeypatch):
|
||||
"pool_recycle": 3600,
|
||||
"pool_size": 30,
|
||||
"pool_use_lifo": False,
|
||||
"pool_reset_on_return": None,
|
||||
}
|
||||
|
||||
assert config["CONSOLE_WEB_URL"] == "https://example.com"
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from flask_restful import marshal
|
||||
from flask_restx import marshal
|
||||
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
@ -13,6 +12,7 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
@ -57,7 +57,7 @@ class TestWorkflowDraftVariableFields:
|
||||
)
|
||||
|
||||
sys_var.id = str(uuid.uuid4())
|
||||
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
sys_var.last_edited_at = naive_utc_now()
|
||||
sys_var.visible = True
|
||||
|
||||
expected_without_value = OrderedDict(
|
||||
@ -88,7 +88,7 @@ class TestWorkflowDraftVariableFields:
|
||||
)
|
||||
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
node_var.last_edited_at = naive_utc_now()
|
||||
|
||||
expected_without_value: OrderedDict[str, Any] = OrderedDict(
|
||||
{
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
"""Test authentication security to prevent user enumeration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
"""Test authentication endpoints for security against user enumeration."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.app = Flask(__name__)
|
||||
self.api = Api(self.app)
|
||||
self.api.add_resource(LoginApi, "/login")
|
||||
self.client = self.app.test_client()
|
||||
self.app.config["TESTING"] = True
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email sends reset password email when registration is allowed."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
result = login_api.post()
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "fail", "data": "token123", "code": "account_not_found"}
|
||||
mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
"""Test that wrong password returns AuthenticationFailedError."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("existing@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AccountNotFound when registration is disabled."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AccountNotFound) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "account_not_found"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
|
||||
from controllers.console.auth.login import ResetPasswordSendEmailApi
|
||||
|
||||
api = ResetPasswordSendEmailApi()
|
||||
result = api.post()
|
||||
|
||||
assert result == {"result": "success", "data": "token123"}
|
||||
@ -0,0 +1,148 @@
|
||||
"""Tests for LLMUsage entity."""
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
|
||||
|
||||
class TestLLMUsage:
|
||||
"""Test cases for LLMUsage class."""
|
||||
|
||||
def test_from_metadata_with_all_tokens(self):
|
||||
"""Test from_metadata when all token types are provided."""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"prompt_unit_price": 0.001,
|
||||
"completion_unit_price": 0.002,
|
||||
"total_price": 0.2,
|
||||
"currency": "USD",
|
||||
"latency": 1.5,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
assert usage.prompt_unit_price == Decimal("0.001")
|
||||
assert usage.completion_unit_price == Decimal("0.002")
|
||||
assert usage.total_price == Decimal("0.2")
|
||||
assert usage.currency == "USD"
|
||||
assert usage.latency == 1.5
|
||||
|
||||
def test_from_metadata_with_prompt_tokens_only(self):
|
||||
"""Test from_metadata when only prompt_tokens is provided."""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 100,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 0
|
||||
assert usage.total_tokens == 100
|
||||
|
||||
def test_from_metadata_with_completion_tokens_only(self):
|
||||
"""Test from_metadata when only completion_tokens is provided."""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 50,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 0
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 50
|
||||
|
||||
def test_from_metadata_calculates_total_when_missing(self):
|
||||
"""Test from_metadata calculates total_tokens when not provided."""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150 # Should be calculated
|
||||
|
||||
def test_from_metadata_with_total_but_no_completion(self):
|
||||
"""
|
||||
Test from_metadata when total_tokens is provided but completion_tokens is 0.
|
||||
This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
|
||||
"""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 479,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 521,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
# This is the key fix - prompt tokens should remain as prompt tokens
|
||||
assert usage.prompt_tokens == 479
|
||||
assert usage.completion_tokens == 0
|
||||
assert usage.total_tokens == 521
|
||||
|
||||
def test_from_metadata_with_empty_metadata(self):
|
||||
"""Test from_metadata with empty metadata."""
|
||||
metadata: LLMUsageMetadata = {}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 0
|
||||
assert usage.completion_tokens == 0
|
||||
assert usage.total_tokens == 0
|
||||
assert usage.currency == "USD"
|
||||
assert usage.latency == 0.0
|
||||
|
||||
def test_from_metadata_preserves_zero_completion_tokens(self):
|
||||
"""
|
||||
Test that zero completion_tokens are preserved when explicitly set.
|
||||
This is important for agent nodes that only use prompt tokens.
|
||||
"""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 1000,
|
||||
"prompt_unit_price": 0.15,
|
||||
"completion_unit_price": 0.60,
|
||||
"prompt_price": 0.00015,
|
||||
"completion_price": 0,
|
||||
"total_price": 0.00015,
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_tokens == 1000
|
||||
assert usage.completion_tokens == 0
|
||||
assert usage.total_tokens == 1000
|
||||
assert usage.prompt_price == Decimal("0.00015")
|
||||
assert usage.completion_price == Decimal(0)
|
||||
assert usage.total_price == Decimal("0.00015")
|
||||
|
||||
def test_from_metadata_with_decimal_values(self):
|
||||
"""Test from_metadata handles decimal values correctly."""
|
||||
metadata: LLMUsageMetadata = {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"prompt_unit_price": "0.001",
|
||||
"completion_unit_price": "0.002",
|
||||
"prompt_price": "0.1",
|
||||
"completion_price": "0.1",
|
||||
"total_price": "0.2",
|
||||
}
|
||||
|
||||
usage = LLMUsage.from_metadata(metadata)
|
||||
|
||||
assert usage.prompt_unit_price == Decimal("0.001")
|
||||
assert usage.completion_unit_price == Decimal("0.002")
|
||||
assert usage.prompt_price == Decimal("0.1")
|
||||
assert usage.completion_price == Decimal("0.1")
|
||||
assert usage.total_price == Decimal("0.2")
|
||||
0
api/tests/unit_tests/core/plugin/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/utils/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/utils/__init__.py
Normal file
460
api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
460
api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
@ -0,0 +1,460 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class TestChunkMerger:
|
||||
def test_file_chunk_initialization(self):
|
||||
"""Test FileChunk initialization."""
|
||||
chunk = FileChunk(1024)
|
||||
assert chunk.bytes_written == 0
|
||||
assert chunk.total_length == 1024
|
||||
assert len(chunk.data) == 1024
|
||||
|
||||
def test_merge_blob_chunks_with_single_complete_chunk(self):
|
||||
"""Test merging a single complete blob chunk."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# First chunk (partial)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
|
||||
),
|
||||
)
|
||||
# Second chunk (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=10, blob=b"World", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# The buffer should contain the complete data
|
||||
assert result[0].message.blob[:10] == b"HelloWorld"
|
||||
|
||||
def test_merge_blob_chunks_with_multiple_files(self):
|
||||
"""Test merging chunks from multiple files."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# File 1, chunk 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=4, blob=b"AB", end=False
|
||||
),
|
||||
)
|
||||
# File 2, chunk 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file2", sequence=0, total_length=4, blob=b"12", end=False
|
||||
),
|
||||
)
|
||||
# File 1, chunk 2 (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=4, blob=b"CD", end=True
|
||||
),
|
||||
)
|
||||
# File 2, chunk 2 (final)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file2", sequence=1, total_length=4, blob=b"34", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 2
|
||||
# Check that both files are properly merged
|
||||
assert all(r.type == ToolInvokeMessage.MessageType.BLOB for r in result)
|
||||
|
||||
def test_merge_blob_chunks_passes_through_non_blob_messages(self):
|
||||
"""Test that non-blob messages pass through unchanged."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Text message
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="Hello"),
|
||||
)
|
||||
# Blob chunk
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=5, blob=b"Test", end=True
|
||||
),
|
||||
)
|
||||
# Another text message
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="World"),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 3
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.TextMessage)
|
||||
assert result[0].message.text == "Hello"
|
||||
assert result[1].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert result[2].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(result[2].message, ToolInvokeMessage.TextMessage)
|
||||
assert result[2].message.text == "World"
|
||||
|
||||
def test_merge_blob_chunks_file_too_large(self):
|
||||
"""Test that error is raised when file exceeds max size."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Send a chunk that would exceed the limit
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=100, blob=b"x" * 1024, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
list(merge_blob_chunks(mock_generator(), max_file_size=1000))
|
||||
assert "File is too large" in str(exc_info.value)
|
||||
|
||||
def test_merge_blob_chunks_chunk_too_large(self):
|
||||
"""Test that error is raised when chunk exceeds max chunk size."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Send a chunk that exceeds the max chunk size
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10000, blob=b"x" * 9000, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
list(merge_blob_chunks(mock_generator(), max_chunk_size=8192))
|
||||
assert "File chunk is too large" in str(exc_info.value)
|
||||
|
||||
def test_merge_blob_chunks_with_agent_invoke_message(self):
|
||||
"""Test that merge_blob_chunks works with AgentInvokeMessage."""
|
||||
|
||||
def mock_generator() -> Generator[AgentInvokeMessage, None, None]:
|
||||
# First chunk
|
||||
yield AgentInvokeMessage(
|
||||
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=AgentInvokeMessage.BlobChunkMessage(
|
||||
id="agent_file", sequence=0, total_length=8, blob=b"Agent", end=False
|
||||
),
|
||||
)
|
||||
# Final chunk
|
||||
yield AgentInvokeMessage(
|
||||
type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=AgentInvokeMessage.BlobChunkMessage(
|
||||
id="agent_file", sequence=1, total_length=8, blob=b"Data", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], AgentInvokeMessage)
|
||||
assert result[0].type == AgentInvokeMessage.MessageType.BLOB
|
||||
|
||||
def test_merge_blob_chunks_preserves_meta(self):
|
||||
"""Test that meta information is preserved in merged messages."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=4, blob=b"Test", end=True
|
||||
),
|
||||
meta={"key": "value"},
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].meta == {"key": "value"}
|
||||
|
||||
def test_merge_blob_chunks_custom_limits(self):
|
||||
"""Test merge_blob_chunks with custom size limits."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# This should work with custom limits
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
|
||||
),
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=500, blob=b"y" * 100, end=True
|
||||
),
|
||||
)
|
||||
|
||||
# Should work with custom limits
|
||||
result = list(merge_blob_chunks(mock_generator(), max_file_size=1000, max_chunk_size=500))
|
||||
assert len(result) == 1
|
||||
|
||||
# Should fail with smaller file size limit
|
||||
def mock_generator2() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(merge_blob_chunks(mock_generator2(), max_file_size=300))
|
||||
|
||||
def test_merge_blob_chunks_data_integrity(self):
|
||||
"""Test that merged chunks exactly match the original data."""
|
||||
# Create original data
|
||||
original_data = b"This is a test message that will be split into chunks for testing purposes."
|
||||
chunk_size = 20
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Split original data into chunks
|
||||
chunks = []
|
||||
for i in range(0, len(original_data), chunk_size):
|
||||
chunk_data = original_data[i : i + chunk_size]
|
||||
is_last = (i + chunk_size) >= len(original_data)
|
||||
chunks.append((i // chunk_size, chunk_data, is_last))
|
||||
|
||||
# Yield chunks
|
||||
for sequence, data, is_end in chunks:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="test_file",
|
||||
sequence=sequence,
|
||||
total_length=len(original_data),
|
||||
blob=data,
|
||||
end=is_end,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# Verify the merged data exactly matches the original
|
||||
assert result[0].message.blob == original_data
|
||||
|
||||
def test_merge_blob_chunks_empty_chunk(self):
|
||||
"""Test handling of empty chunks."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# First chunk with data
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
|
||||
),
|
||||
)
|
||||
# Empty chunk in the middle
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=1, total_length=10, blob=b"", end=False
|
||||
),
|
||||
)
|
||||
# Final chunk with data
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="file1", sequence=2, total_length=10, blob=b"World", end=True
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
# The final blob should contain "Hello" followed by "World"
|
||||
assert result[0].message.blob[:10] == b"HelloWorld"
|
||||
|
||||
def test_merge_blob_chunks_single_chunk_file(self):
|
||||
"""Test file that arrives as a single complete chunk."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Single chunk that is both first and last
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="single_chunk_file",
|
||||
sequence=0,
|
||||
total_length=11,
|
||||
blob=b"Single Data",
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert result[0].type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert result[0].message.blob == b"Single Data"
|
||||
|
||||
def test_merge_blob_chunks_concurrent_files(self):
|
||||
"""Test that chunks from different files are properly separated."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Interleave chunks from three different files
|
||||
files_data = {
|
||||
"file1": b"First file content",
|
||||
"file2": b"Second file data",
|
||||
"file3": b"Third file",
|
||||
}
|
||||
|
||||
# First chunk from each file
|
||||
for file_id, data in files_data.items():
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id=file_id,
|
||||
sequence=0,
|
||||
total_length=len(data),
|
||||
blob=data[:6],
|
||||
end=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Second chunk from each file (final)
|
||||
for file_id, data in files_data.items():
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id=file_id,
|
||||
sequence=1,
|
||||
total_length=len(data),
|
||||
blob=data[6:],
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 3
|
||||
|
||||
# Extract the blob data from results
|
||||
blobs = set()
|
||||
for r in result:
|
||||
assert isinstance(r.message, ToolInvokeMessage.BlobMessage)
|
||||
blobs.add(r.message.blob)
|
||||
expected = {b"First file content", b"Second file data", b"Third file"}
|
||||
assert blobs == expected
|
||||
|
||||
def test_merge_blob_chunks_exact_buffer_size(self):
|
||||
"""Test that data fitting exactly in buffer works correctly."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Create data that exactly fills the declared buffer
|
||||
exact_data = b"X" * 100
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="exact_file",
|
||||
sequence=0,
|
||||
total_length=100,
|
||||
blob=exact_data[:50],
|
||||
end=False,
|
||||
),
|
||||
)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="exact_file",
|
||||
sequence=1,
|
||||
total_length=100,
|
||||
blob=exact_data[50:],
|
||||
end=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert len(result[0].message.blob) == 100
|
||||
assert result[0].message.blob == b"X" * 100
|
||||
|
||||
def test_merge_blob_chunks_large_file_simulation(self):
|
||||
"""Test handling of a large file split into many chunks."""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Simulate a 1MB file split into 128 chunks of 8KB each
|
||||
chunk_size = 8192
|
||||
num_chunks = 128
|
||||
total_size = chunk_size * num_chunks
|
||||
|
||||
for i in range(num_chunks):
|
||||
# Create unique data for each chunk to verify ordering
|
||||
chunk_data = bytes([i % 256]) * chunk_size
|
||||
is_last = i == num_chunks - 1
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="large_file",
|
||||
sequence=i,
|
||||
total_length=total_size,
|
||||
blob=chunk_data,
|
||||
end=is_last,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert len(result[0].message.blob) == 1024 * 1024
|
||||
|
||||
# Verify the data pattern is correct
|
||||
merged_data = result[0].message.blob
|
||||
chunk_size = 8192
|
||||
num_chunks = 128
|
||||
for i in range(num_chunks):
|
||||
chunk_start = i * chunk_size
|
||||
chunk_end = chunk_start + chunk_size
|
||||
expected_byte = i % 256
|
||||
chunk = merged_data[chunk_start:chunk_end]
|
||||
assert all(b == expected_byte for b in chunk), f"Chunk {i} has incorrect data"
|
||||
|
||||
def test_merge_blob_chunks_sequential_order_required(self):
|
||||
"""
|
||||
Test note: The current implementation assumes chunks arrive in sequential order.
|
||||
Out-of-order chunks would need additional logic to handle properly.
|
||||
This test documents the expected behavior with sequential chunks.
|
||||
"""
|
||||
|
||||
def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
# Chunks arriving in correct sequential order
|
||||
data_parts = [b"First", b"Second", b"Third"]
|
||||
total_length = sum(len(part) for part in data_parts)
|
||||
|
||||
for i, part in enumerate(data_parts):
|
||||
is_last = i == len(data_parts) - 1
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
|
||||
message=ToolInvokeMessage.BlobChunkMessage(
|
||||
id="ordered_file",
|
||||
sequence=i,
|
||||
total_length=total_length,
|
||||
blob=part,
|
||||
end=is_last,
|
||||
),
|
||||
)
|
||||
|
||||
result = list(merge_blob_chunks(mock_generator()))
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert result[0].message.blob == b"FirstSecondThird"
|
||||
@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow execution data.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -13,6 +12,7 @@ import pytest
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
@ -56,7 +56,7 @@ def sample_workflow_execution():
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
|
||||
@ -199,7 +199,7 @@ class TestCeleryWorkflowExecutionRepository:
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
@ -208,7 +208,7 @@ class TestCeleryWorkflowExecutionRepository:
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input2": "value2"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
@ -235,7 +235,7 @@ class TestCeleryWorkflowExecutionRepository:
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow node execution data.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -18,6 +17,7 @@ from core.workflow.entities.workflow_node_execution import (
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -65,7 +65,7 @@ def sample_workflow_node_execution():
|
||||
title="Test Node",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
|
||||
@ -263,7 +263,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
title="Node 1",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
@ -276,7 +276,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
title="Node 2",
|
||||
inputs={"input2": "value2"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
@ -314,7 +314,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
title="Node 2",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
@ -327,7 +327,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
title="Node 1",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save in random order
|
||||
|
||||
@ -2,19 +2,19 @@
|
||||
Unit tests for the RepositoryFactory.
|
||||
|
||||
This module tests the factory pattern implementation for creating repository instances
|
||||
based on configuration, including error handling and validation.
|
||||
based on configuration, including error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from libs.module_loading import import_string
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
@ -23,98 +23,30 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
class TestRepositoryFactory:
|
||||
"""Test cases for RepositoryFactory."""
|
||||
|
||||
def test_import_class_success(self):
|
||||
def test_import_string_success(self):
|
||||
"""Test successful class import."""
|
||||
# Test importing a real class
|
||||
class_path = "unittest.mock.MagicMock"
|
||||
result = DifyCoreRepositoryFactory._import_class(class_path)
|
||||
result = import_string(class_path)
|
||||
assert result is MagicMock
|
||||
|
||||
def test_import_class_invalid_path(self):
|
||||
def test_import_string_invalid_path(self):
|
||||
"""Test import with invalid module path."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalid.module.path")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
import_string("invalid.module.path")
|
||||
assert "No module named" in str(exc_info.value)
|
||||
|
||||
def test_import_class_invalid_class_name(self):
|
||||
def test_import_string_invalid_class_name(self):
|
||||
"""Test import with invalid class name."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
import_string("unittest.mock.NonExistentClass")
|
||||
assert "does not define" in str(exc_info.value)
|
||||
|
||||
def test_import_class_malformed_path(self):
|
||||
def test_import_string_malformed_path(self):
|
||||
"""Test import with malformed path (no dots)."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalidpath")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_success(self):
|
||||
"""Test successful interface validation."""
|
||||
|
||||
# Create a mock class that implements the required methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface class
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception when all methods are present
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_repository_interface_missing_methods(self):
|
||||
"""Test interface validation with missing methods."""
|
||||
|
||||
# Create a mock class that's missing required methods
|
||||
class IncompleteRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
# Missing get_by_id method
|
||||
|
||||
# Create a mock interface that requires both methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def missing_method(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
|
||||
assert "does not implement required methods" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_with_private_methods(self):
|
||||
"""Test that private methods are ignored during interface validation."""
|
||||
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with private methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Should not raise exception - private methods should be ignored
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
import_string("invalidpath")
|
||||
assert "doesn't look like a module path" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_success(self, mock_config):
|
||||
@ -133,11 +65,8 @@ class TestRepositoryFactory:
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
@ -170,34 +99,7 @@ class TestRepositoryFactory:
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock the import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
|
||||
mocker.patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
|
||||
@ -212,11 +114,8 @@ class TestRepositoryFactory:
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock the validation methods to succeed
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
# Mock import_string to return a failing class
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
@ -243,11 +142,8 @@ class TestRepositoryFactory:
|
||||
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
@ -280,34 +176,7 @@ class TestRepositoryFactory:
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowNodeExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock the import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
|
||||
mocker.patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
@ -322,11 +191,8 @@ class TestRepositoryFactory:
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock the validation methods to succeed
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
# Mock import_string to return a failing class
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
@ -359,11 +225,8 @@ class TestRepositoryFactory:
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_engine, # Using Engine instead of sessionmaker
|
||||
user=mock_user,
|
||||
|
||||
308
api/tests/unit_tests/core/test_provider_configuration.py
Normal file
308
api/tests/unit_tests/core/test_provider_configuration.py
Normal file
@ -0,0 +1,308 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
ModelSettings,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
RestrictModel,
|
||||
SystemConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity():
|
||||
"""Mock provider entity with basic configuration"""
|
||||
provider_entity = ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
|
||||
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
|
||||
background="background.png",
|
||||
help=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
)
|
||||
|
||||
return provider_entity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_system_configuration():
|
||||
"""Mock system configuration"""
|
||||
quota_config = QuotaConfiguration(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=1000,
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
|
||||
)
|
||||
|
||||
system_config = SystemConfiguration(
|
||||
enabled=True,
|
||||
credentials={"openai_api_key": "test_key"},
|
||||
quota_configurations=[quota_config],
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
|
||||
return system_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_configuration():
|
||||
"""Mock custom configuration"""
|
||||
custom_config = CustomConfiguration(provider=None, models=[])
|
||||
return custom_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
|
||||
"""Create a test provider configuration instance"""
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
return ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfiguration:
|
||||
"""Test cases for ProviderConfiguration class"""
|
||||
|
||||
def test_get_current_credentials_system_provider_success(self, provider_configuration):
|
||||
"""Test successfully getting credentials from system provider"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_get_current_credentials_model_disabled(self, provider_configuration):
|
||||
"""Test getting credentials when model is disabled"""
|
||||
# Arrange
|
||||
model_setting = ModelSettings(
|
||||
model="gpt-4",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=False,
|
||||
load_balancing_configs=[],
|
||||
has_invalid_load_balancing_configs=False,
|
||||
)
|
||||
provider_configuration.model_settings = [model_setting]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
|
||||
provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
|
||||
"""Test getting credentials from custom provider with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.CUSTOM
|
||||
|
||||
mock_model_config = Mock()
|
||||
mock_model_config.model_type = ModelType.LLM
|
||||
mock_model_config.model = "gpt-4"
|
||||
mock_model_config.credentials = {"openai_api_key": "custom_key"}
|
||||
provider_configuration.custom_configuration.models = [mock_model_config]
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "custom_key"}
|
||||
|
||||
def test_get_system_configuration_status_active(self, provider_configuration):
|
||||
"""Test getting active system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.ACTIVE
|
||||
|
||||
def test_get_system_configuration_status_unsupported(self, provider_configuration):
|
||||
"""Test getting unsupported system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
|
||||
"""Test getting quota exceeded system configuration status"""
|
||||
# Arrange
|
||||
provider_configuration.system_configuration.enabled = True
|
||||
quota_config = provider_configuration.system_configuration.quota_configurations[0]
|
||||
quota_config.is_valid = False
|
||||
|
||||
# Act
|
||||
status = provider_configuration.get_system_configuration_status()
|
||||
|
||||
# Assert
|
||||
assert status == SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def test_is_custom_configuration_available_with_provider(self, provider_configuration):
|
||||
"""Test custom configuration availability with provider credentials"""
|
||||
# Arrange
|
||||
mock_provider = Mock()
|
||||
mock_provider.available_credentials = ["openai_api_key"]
|
||||
provider_configuration.custom_configuration.provider = mock_provider
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_with_models(self, provider_configuration):
|
||||
"""Test custom configuration availability with model configurations"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = [Mock()]
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_custom_configuration_available_false(self, provider_configuration):
|
||||
"""Test custom configuration not available"""
|
||||
# Arrange
|
||||
provider_configuration.custom_configuration.provider = None
|
||||
provider_configuration.custom_configuration.models = []
|
||||
|
||||
# Act
|
||||
result = provider_configuration.is_custom_configuration_available()
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record successfully"""
|
||||
# Arrange
|
||||
mock_provider = Mock(spec=Provider)
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result == mock_provider
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_provider_record_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting provider record when not found"""
|
||||
# Arrange
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = provider_configuration._get_provider_record(mock_session_instance)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_init_with_customizable_model_only(
|
||||
self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
|
||||
):
|
||||
"""Test initialization with customizable model only configuration"""
|
||||
# Arrange
|
||||
mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]
|
||||
|
||||
# Act
|
||||
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
|
||||
config = ProviderConfiguration(
|
||||
tenant_id="test_tenant",
|
||||
provider=mock_provider_entity,
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=mock_system_configuration,
|
||||
custom_configuration=mock_custom_configuration,
|
||||
model_settings=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods
|
||||
|
||||
def test_get_current_credentials_with_restricted_models(self, provider_configuration):
|
||||
"""Test getting credentials with model restrictions"""
|
||||
# Arrange
|
||||
provider_configuration.using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")
|
||||
|
||||
# Assert
|
||||
assert credentials is not None
|
||||
assert "openai_api_key" in credentials
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential successfully"""
|
||||
# Arrange
|
||||
credential_id = "test_credential_id"
|
||||
mock_credential = Mock()
|
||||
mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}'
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential
|
||||
|
||||
# Act
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = {"openai_api_key": "test_key"}
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
|
||||
# Assert
|
||||
assert result == {"openai_api_key": "test_key"}
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential when not found"""
|
||||
# Arrange
|
||||
credential_id = "nonexistent_credential_id"
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = None
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
assert result is None
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
@ -1,190 +1,185 @@
|
||||
# from core.entities.provider_entities import ModelSettings
|
||||
# from core.model_runtime.entities.model_entities import ModelType
|
||||
# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
# from core.provider_manager import ProviderManager
|
||||
# from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
# def test__to_model_settings(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 2
|
||||
# assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
# assert result[0].load_balancing_configs[1].name == "first"
|
||||
return mock_entity
|
||||
|
||||
|
||||
# def test__to_model_settings_only_one_lb(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=True,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# )
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(
|
||||
# provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 2
|
||||
assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
# def test__to_model_settings_lb_disabled(mocker):
|
||||
# # Get all provider entities
|
||||
# model_provider_factory = ModelProviderFactory("test_tenant")
|
||||
# provider_entities = model_provider_factory.get_providers()
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
)
|
||||
]
|
||||
|
||||
# provider_entity = None
|
||||
# for provider in provider_entities:
|
||||
# if provider.provider == "openai":
|
||||
# provider_entity = provider
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# # Mocking the inputs
|
||||
# provider_model_settings = [
|
||||
# ProviderModelSetting(
|
||||
# id="id",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# enabled=True,
|
||||
# load_balancing_enabled=False,
|
||||
# )
|
||||
# ]
|
||||
# load_balancing_model_configs = [
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id1",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="__inherit__",
|
||||
# encrypted_config=None,
|
||||
# enabled=True,
|
||||
# ),
|
||||
# LoadBalancingModelConfig(
|
||||
# id="id2",
|
||||
# tenant_id="tenant_id",
|
||||
# provider_name="openai",
|
||||
# model_name="gpt-4",
|
||||
# model_type="text-generation",
|
||||
# name="first",
|
||||
# encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
# enabled=True,
|
||||
# ),
|
||||
# ]
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# mocker.patch(
|
||||
# "core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
# return_value={"openai_api_key": "fake_key"}
|
||||
# )
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# provider_manager = ProviderManager()
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
# # Running the method
|
||||
# result = provider_manager._to_model_settings(provider_entity,
|
||||
# provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# # Asserting that the result is as expected
|
||||
# assert len(result) == 1
|
||||
# assert isinstance(result[0], ModelSettings)
|
||||
# assert result[0].model == "gpt-4"
|
||||
# assert result[0].model_type == ModelType.LLM
|
||||
# assert result[0].enabled is True
|
||||
# assert len(result[0].load_balancing_configs) == 0
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType:
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
expected_non_array_types = [
|
||||
SegmentType.INTEGER,
|
||||
@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType:
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
SegmentType.BOOLEAN,
|
||||
]
|
||||
|
||||
for seg_type in expected_array_types:
|
||||
|
||||
@ -0,0 +1,729 @@
|
||||
"""
|
||||
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
|
||||
|
||||
This module provides thorough testing of the validation logic for all SegmentType values,
|
||||
including edge cases, error conditions, and different ArrayValidation strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
def create_test_file(
|
||||
file_type: FileType = FileType.DOCUMENT,
|
||||
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
|
||||
filename: str = "test.txt",
|
||||
extension: str = ".txt",
|
||||
mime_type: str = "text/plain",
|
||||
size: int = 1024,
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing."""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
|
||||
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
storage_key="test-storage-key",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationTestCase:
|
||||
"""Test case data structure for validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArrayValidationTestCase:
|
||||
"""Test case data structure for array validation tests."""
|
||||
|
||||
segment_type: SegmentType
|
||||
value: Any
|
||||
array_validation: ArrayValidation
|
||||
expected: bool
|
||||
description: str
|
||||
|
||||
def get_id(self):
|
||||
return self.description
|
||||
|
||||
|
||||
# Test data construction functions
|
||||
def get_boolean_cases() -> list[ValidationTestCase]:
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"),
|
||||
# Invalid values
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_number_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid number values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"),
|
||||
ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"),
|
||||
# invalid number values
|
||||
ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"),
|
||||
ValidationTestCase(SegmentType.NUMBER, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"),
|
||||
]
|
||||
|
||||
|
||||
def get_string_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid string values."""
|
||||
return [
|
||||
# valid values
|
||||
ValidationTestCase(SegmentType.STRING, "", True, "Empty string"),
|
||||
ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"),
|
||||
ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"),
|
||||
ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"),
|
||||
# invalid values
|
||||
ValidationTestCase(SegmentType.STRING, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.STRING, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.STRING, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.STRING, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_object_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid object values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.OBJECT, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"),
|
||||
]
|
||||
|
||||
|
||||
def get_secret_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid secret values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
|
||||
ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
|
||||
ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
|
||||
ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
|
||||
ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
|
||||
]
|
||||
|
||||
|
||||
def get_file_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid file values."""
|
||||
test_file = create_test_file()
|
||||
image_file = create_test_file(
|
||||
file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
|
||||
)
|
||||
remote_file = create_test_file(
|
||||
transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
|
||||
)
|
||||
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"),
|
||||
ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"),
|
||||
ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
|
||||
ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
|
||||
ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
|
||||
ValidationTestCase(SegmentType.FILE, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
|
||||
]
|
||||
|
||||
|
||||
def get_none_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid none values."""
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(SegmentType.NONE, None, True, "None value"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
|
||||
ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
|
||||
ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
|
||||
ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
|
||||
ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
|
||||
ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
|
||||
ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_ANY validation."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"Mixed types with FIRST validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY,
|
||||
[1, "string", 3.14, {"key": "value"}, True],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Mixed types with ALL validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with NONE strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", "world"],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Valid strings with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, 456],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Invalid elements with NONE validation",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["valid", 123, True],
|
||||
ArrayValidation.NONE,
|
||||
True,
|
||||
"Mixed types with NONE validation",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with FIRST strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
["hello", 123, True],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid, others invalid",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING,
|
||||
[123, "hello", "world"],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid, others valid",
|
||||
),
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"),
|
||||
]
|
||||
|
||||
|
||||
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_STRING validation with ALL strategy."""
|
||||
return [
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_NUMBER validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
[float("inf"), float("-inf"), float("nan")],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"Special float values",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_OBJECT validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "not an object"],
|
||||
ArrayValidation.FIRST,
|
||||
True,
|
||||
"First valid object",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
["not an object", {"valid": "object"}],
|
||||
ArrayValidation.FIRST,
|
||||
False,
|
||||
"First invalid",
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{}, {"a": 1}, {"nested": {"key": "value"}}],
|
||||
ArrayValidation.ALL,
|
||||
True,
|
||||
"All valid objects",
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
[{"valid": "object"}, "invalid", {"another": "object"}],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_FILE validation with different strategies."""
|
||||
file1 = create_test_file(filename="file1.txt")
|
||||
file2 = create_test_file(filename="file2.txt")
|
||||
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_BOOLEAN validation with different strategies."""
|
||||
return [
|
||||
# NONE strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"
|
||||
),
|
||||
# FIRST strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"
|
||||
),
|
||||
# ALL strategy
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"
|
||||
),
|
||||
ArrayValidationTestCase(
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
[True, "false", False],
|
||||
ArrayValidation.ALL,
|
||||
False,
|
||||
"One invalid element (string)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestSegmentTypeIsValid:
|
||||
"""Test suite for SegmentType.is_valid method covering all non-array types."""
|
||||
|
||||
@pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
|
||||
def test_boolean_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
|
||||
def test_number_validation(self, case: ValidationTestCase):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
|
||||
def test_string_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
|
||||
def test_object_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
|
||||
def test_secret_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
|
||||
def test_file_validation(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
|
||||
def test_none_validation_valid_cases(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
def test_unsupported_segment_type_raises_assertion_error(self):
|
||||
"""Test that unsupported SegmentType values raise AssertionError."""
|
||||
# GROUP is not handled in is_valid method
|
||||
with pytest.raises(AssertionError, match="this statement should be unreachable"):
|
||||
SegmentType.GROUP.is_valid("any value")
|
||||
|
||||
|
||||
class TestSegmentTypeArrayValidation:
|
||||
"""Test suite for SegmentType._validate_array method and array type validation."""
|
||||
|
||||
def test_array_validation_non_list_values(self):
|
||||
"""Test that non-list values return False for all array types."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
non_list_values = [
|
||||
"not a list",
|
||||
123,
|
||||
3.14,
|
||||
True,
|
||||
None,
|
||||
{"key": "value"},
|
||||
create_test_file(),
|
||||
]
|
||||
|
||||
for array_type in array_types:
|
||||
for value in non_list_values:
|
||||
assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
|
||||
|
||||
def test_empty_array_validation(self):
|
||||
"""Test that empty arrays are valid for all array types regardless of validation strategy."""
|
||||
array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
|
||||
validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
|
||||
|
||||
for array_type in array_types:
|
||||
for strategy in validation_strategies:
|
||||
assert array_type.is_valid([], strategy) is True, (
|
||||
f"{array_type} should accept empty array with {strategy}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_any_validation(self, case):
|
||||
"""Test ARRAY_ANY validation accepts any list regardless of content."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_none_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with NONE strategy (no element validation)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_first_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
|
||||
def test_array_string_validation_with_all_strategy(self, case):
|
||||
"""Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_number_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_NUMBER validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_object_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_OBJECT validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_file_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_FILE validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
@pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
|
||||
def test_array_boolean_validation_with_different_strategies(self, case):
|
||||
"""Test ARRAY_BOOLEAN validation with different validation strategies."""
|
||||
assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
|
||||
|
||||
def test_default_array_validation_strategy(self):
|
||||
"""Test that default array validation strategy is FIRST."""
|
||||
# When no array_validation parameter is provided, it should default to FIRST
|
||||
assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid
|
||||
assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid
|
||||
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid
|
||||
assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid
|
||||
|
||||
def test_array_validation_edge_cases(self):
|
||||
"""Test edge cases for array validation."""
|
||||
# Test with nested arrays (should be invalid for specific array types)
|
||||
nested_array = [["nested", "array"], ["another", "nested"]]
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
|
||||
|
||||
# Test with very large arrays (performance consideration)
|
||||
large_valid_array = ["string"] * 1000
|
||||
large_mixed_array = ["string"] * 999 + [123] # Last element invalid
|
||||
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
|
||||
|
||||
|
||||
class TestSegmentTypeValidationIntegration:
|
||||
"""Integration tests for SegmentType validation covering interactions between methods."""
|
||||
|
||||
def test_non_array_types_ignore_array_validation_parameter(self):
|
||||
"""Test that non-array types ignore the array_validation parameter."""
|
||||
non_array_types = [
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
]
|
||||
|
||||
for segment_type in non_array_types:
|
||||
# Create appropriate valid value for each type
|
||||
valid_value: Any
|
||||
if segment_type == SegmentType.STRING:
|
||||
valid_value = "test"
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
valid_value = 42
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
valid_value = True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
valid_value = {"key": "value"}
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
valid_value = "secret"
|
||||
elif segment_type == SegmentType.FILE:
|
||||
valid_value = create_test_file()
|
||||
elif segment_type == SegmentType.NONE:
|
||||
valid_value = None
|
||||
else:
|
||||
continue # Skip unsupported types
|
||||
|
||||
# All array validation strategies should give the same result
|
||||
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
|
||||
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
|
||||
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
|
||||
|
||||
assert result_none == result_first == result_all == True, (
|
||||
f"{segment_type} should ignore array_validation parameter"
|
||||
)
|
||||
|
||||
def test_comprehensive_type_coverage(self):
|
||||
"""Test that all SegmentType enum values are covered in validation tests."""
|
||||
all_segment_types = set(SegmentType)
|
||||
|
||||
# Types that should be handled by is_valid method
|
||||
handled_types = {
|
||||
# Non-array types
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
}
|
||||
|
||||
# Types that are not handled by is_valid (should raise AssertionError)
|
||||
unhandled_types = {
|
||||
SegmentType.GROUP,
|
||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||
}
|
||||
|
||||
# Verify all types are accounted for
|
||||
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
|
||||
|
||||
# Test that handled types work correctly
|
||||
for segment_type in handled_types:
|
||||
if segment_type.is_array_type():
|
||||
# Test with empty array (should always be valid)
|
||||
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
|
||||
else:
|
||||
# Test with appropriate valid value
|
||||
if segment_type == SegmentType.STRING:
|
||||
assert segment_type.is_valid("test") is True
|
||||
elif segment_type == SegmentType.NUMBER:
|
||||
assert segment_type.is_valid(42) is True
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
assert segment_type.is_valid(True) is True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
assert segment_type.is_valid({}) is True
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
assert segment_type.is_valid("secret") is True
|
||||
elif segment_type == SegmentType.FILE:
|
||||
assert segment_type.is_valid(create_test_file()) is True
|
||||
elif segment_type == SegmentType.NONE:
|
||||
assert segment_type.is_valid(None) is True
|
||||
|
||||
def test_boolean_vs_integer_type_distinction(self):
|
||||
"""Test the important distinction between boolean and integer types in validation."""
|
||||
# This tests the comment in the code about bool being a subclass of int
|
||||
|
||||
# Boolean type should only accept actual booleans, not integers
|
||||
assert SegmentType.BOOLEAN.is_valid(True) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(False) is True
|
||||
assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean
|
||||
assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean
|
||||
|
||||
# Number type should accept both integers and floats, including booleans (since bool is subclass of int)
|
||||
assert SegmentType.NUMBER.is_valid(42) is True
|
||||
assert SegmentType.NUMBER.is_valid(3.14) is True
|
||||
assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int
|
||||
assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int
|
||||
|
||||
def test_array_validation_recursive_behavior(self):
|
||||
"""Test that array validation correctly handles recursive validation calls."""
|
||||
# When validating array elements, _validate_array calls is_valid recursively
|
||||
# with ArrayValidation.NONE to avoid infinite recursion
|
||||
|
||||
# Test nested validation doesn't cause issues
|
||||
nested_arrays = [["inner", "array"], ["another", "inner"]]
|
||||
|
||||
# ARRAY_ANY should accept nested arrays
|
||||
assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
|
||||
|
||||
# ARRAY_STRING should reject nested arrays (first element is not a string)
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
|
||||
assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False
|
||||
@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -15,6 +14,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine
|
||||
|
||||
|
||||
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
|
||||
route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now())
|
||||
|
||||
parallel_id = graph.node_parallel_mapping.get(next_node_id)
|
||||
parallel_start_node_id = None
|
||||
@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
)
|
||||
|
||||
route_node_state.status = RouteNodeState.Status.SUCCESS
|
||||
route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
route_node_state.finished_at = naive_utc_now()
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig
|
||||
|
||||
|
||||
class TestParameterConfig:
|
||||
def test_select_type(self):
|
||||
data = {
|
||||
"name": "yes_or_no",
|
||||
"type": "select",
|
||||
"options": ["yes", "no"],
|
||||
"description": "a simple select made of `yes` and `no`",
|
||||
"required": True,
|
||||
}
|
||||
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.STRING
|
||||
assert pc.options == data["options"]
|
||||
|
||||
def test_validate_bool_type(self):
|
||||
data = {
|
||||
"name": "boolean",
|
||||
"type": "bool",
|
||||
"description": "a simple boolean parameter",
|
||||
"required": True,
|
||||
}
|
||||
pc = ParameterConfig.model_validate(data)
|
||||
assert pc.type == SegmentType.BOOLEAN
|
||||
@ -0,0 +1,567 @@
|
||||
"""
|
||||
Test cases for ParameterExtractorNode._validate_result and _transform_result methods.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
|
||||
from core.workflow.nodes.parameter_extractor.exc import (
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidSelectValueError,
|
||||
InvalidValueTypeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidTestCase:
|
||||
"""Test case data for valid scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorTestCase:
|
||||
"""Test case data for error scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
result: dict[str, Any]
|
||||
expected_exception: type[Exception]
|
||||
expected_message: str
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformTestCase:
|
||||
"""Test case data for transformation scenarios."""
|
||||
|
||||
name: str
|
||||
parameters: list[ParameterConfig]
|
||||
input_result: dict[str, Any]
|
||||
expected_result: dict[str, Any]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class TestParameterExtractorNodeMethods:
|
||||
"""Test helper class that provides access to the methods under test."""
|
||||
|
||||
def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _validate_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._validate_result(data=data, result=result)
|
||||
|
||||
def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Wrapper to call _transform_result method."""
|
||||
node = ParameterExtractorNode.__new__(ParameterExtractorNode)
|
||||
return node._transform_result(data=data, result=result)
|
||||
|
||||
|
||||
class TestValidateResult:
|
||||
"""Test cases for _validate_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_valid_test_cases() -> list[ValidTestCase]:
|
||||
"""Get test cases that should pass validation."""
|
||||
return [
|
||||
ValidTestCase(
|
||||
name="single_string_parameter",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": 25},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_number_parameter_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
result={"price": 19.99},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_true",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="single_bool_parameter_false",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": False},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="select_parameter_valid_option",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # pyright: ignore[reportArgumentType]
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "active"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_string_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", "tag2", "tag3"]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_number_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, 92.5, 78]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="array_object_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="multiple_parameters",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
],
|
||||
result={"name": "John", "age": 25, "active": True},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="optional_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False),
|
||||
],
|
||||
result={"name": "John", "nickname": "Johnny"},
|
||||
),
|
||||
ValidTestCase(
|
||||
name="empty_array_parameter",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": []},
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_error_test_cases() -> list[ErrorTestCase]:
|
||||
"""Get test cases that should raise exceptions."""
|
||||
return [
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_few",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
],
|
||||
result={"name": "John"},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_of_parameters_too_many",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
result={"name": "John", "age": 25},
|
||||
expected_exception=InvalidNumberOfParametersError,
|
||||
expected_message="Invalid number of parameters",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_none",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
],
|
||||
result={"name": None}, # Parameter present but None value, will trigger type check first
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_select_value",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
result={"status": "pending"},
|
||||
expected_exception=InvalidSelectValueError,
|
||||
expected_message="Invalid `select` value for parameter status",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_number_value_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
result={"age": "twenty-five"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string",
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_bool_value_string",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
result={"active": "yes"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter active, expected segment type: boolean, actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_string_value_number",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="description", type=SegmentType.STRING, description="Description", required=True
|
||||
)
|
||||
],
|
||||
result={"description": 123},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter description, expected segment type: string, actual_type: integer"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_value_not_list",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": "tag1,tag2,tag3"},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: string"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_number_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
result={"scores": [85, "ninety-two", 78]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_string_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
result={"tags": ["tag1", 123, "tag3"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="invalid_array_object_wrong_element_type",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
result={"items": [{"name": "item1"}, "item2"]},
|
||||
expected_exception=InvalidValueTypeError,
|
||||
expected_message=(
|
||||
"Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]"
|
||||
),
|
||||
),
|
||||
ErrorTestCase(
|
||||
name="required_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False),
|
||||
],
|
||||
result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count
|
||||
expected_exception=RequiredParameterMissingError,
|
||||
expected_message="Parameter name is required",
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name)
|
||||
def test_validate_result_valid_cases(self, test_case):
|
||||
"""Test _validate_result with valid inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.validate_result(data=node_data, result=test_case.result)
|
||||
assert result == test_case.result, f"Failed for case: {test_case.name}"
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name)
|
||||
def test_validate_result_error_cases(self, test_case):
|
||||
"""Test _validate_result with invalid inputs that should raise exceptions."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
with pytest.raises(test_case.expected_exception) as exc_info:
|
||||
helper.validate_result(data=node_data, result=test_case.result)
|
||||
|
||||
assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}"
|
||||
|
||||
|
||||
class TestTransformResult:
|
||||
"""Test cases for _transform_result method."""
|
||||
|
||||
@staticmethod
|
||||
def get_transform_test_cases() -> list[TransformTestCase]:
|
||||
"""Get test cases for result transformation."""
|
||||
return [
|
||||
# String parameter transformation
|
||||
TransformTestCase(
|
||||
name="string_parameter_present",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={"name": "John"},
|
||||
expected_result={"name": "John"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="string_parameter_missing",
|
||||
parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)],
|
||||
input_result={},
|
||||
expected_result={"name": ""},
|
||||
),
|
||||
# Number parameter transformation
|
||||
TransformTestCase(
|
||||
name="number_parameter_int_present",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": 25},
|
||||
expected_result={"age": 25},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_float_present",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": 19.99},
|
||||
expected_result={"price": 19.99},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_missing",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={},
|
||||
expected_result={"age": 0},
|
||||
),
|
||||
# Bool parameter transformation
|
||||
TransformTestCase(
|
||||
name="bool_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"active": False},
|
||||
),
|
||||
# Select parameter transformation
|
||||
TransformTestCase(
|
||||
name="select_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={"status": "active"},
|
||||
expected_result={"status": "active"},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="select_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(
|
||||
name="status",
|
||||
type="select", # type: ignore
|
||||
description="Status",
|
||||
required=True,
|
||||
options=["active", "inactive"],
|
||||
)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"status": ""},
|
||||
),
|
||||
# Array parameter transformation - present cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={"tags": ["tag1", "tag2"]},
|
||||
expected_result={
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, 92.5]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "92.5", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78])
|
||||
},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_present",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={"items": [{"name": "item1"}, {"name": "item2"}]},
|
||||
expected_result={
|
||||
"items": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}]
|
||||
)
|
||||
},
|
||||
),
|
||||
# Array parameter transformation - missing cases
|
||||
TransformTestCase(
|
||||
name="array_string_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])},
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_object_parameter_missing",
|
||||
parameters=[
|
||||
ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True)
|
||||
],
|
||||
input_result={},
|
||||
expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])},
|
||||
),
|
||||
# Multiple parameters transformation
|
||||
TransformTestCase(
|
||||
name="multiple_parameters_mixed",
|
||||
parameters=[
|
||||
ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True),
|
||||
ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True),
|
||||
ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True),
|
||||
ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True),
|
||||
],
|
||||
input_result={"name": "John", "age": 25},
|
||||
expected_result={
|
||||
"name": "John",
|
||||
"age": 25,
|
||||
"active": False,
|
||||
"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]),
|
||||
},
|
||||
),
|
||||
# Number parameter transformation with string conversion
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_float",
|
||||
parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)],
|
||||
input_result={"price": "19.99"},
|
||||
expected_result={"price": 19.99}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_string_to_int",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "25"},
|
||||
expected_result={"age": 25}, # String not converted, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_invalid_string",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": "invalid_number"},
|
||||
expected_result={"age": 0}, # Invalid string conversion fails, falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="number_parameter_non_string_non_number",
|
||||
parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)],
|
||||
input_result={"age": ["not_a_number"]}, # Non-string, non-number value
|
||||
expected_result={"age": 0}, # Falls back to default
|
||||
),
|
||||
TransformTestCase(
|
||||
name="array_number_parameter_with_invalid_string_conversion",
|
||||
parameters=[
|
||||
ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True)
|
||||
],
|
||||
input_result={"scores": [85, "invalid", "78"]},
|
||||
expected_result={
|
||||
"scores": build_segment_with_type(
|
||||
segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78]
|
||||
) # Invalid string skipped
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name)
|
||||
def test_transform_result_cases(self, test_case):
|
||||
"""Test _transform_result with various inputs."""
|
||||
helper = TestParameterExtractorNodeMethods()
|
||||
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="Test Node",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
query=["test_query"],
|
||||
parameters=test_case.parameters,
|
||||
reasoning_mode="function_call",
|
||||
vision=VisionConfig(),
|
||||
)
|
||||
|
||||
result = helper.transform_result(data=node_data, result=test_case.input_result)
|
||||
assert result == test_case.expected_result, (
|
||||
f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}"
|
||||
)
|
||||
@ -2,6 +2,8 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
@ -272,3 +274,220 @@ def test_array_file_contains_file_name():
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def _get_test_conditions() -> list:
|
||||
conditions = [
|
||||
# Test boolean "is" operator
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},
|
||||
# Test boolean "is not" operator
|
||||
{"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"},
|
||||
# Test boolean "=" operator
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"},
|
||||
# Test boolean "≠" operator
|
||||
{"comparison_operator": "≠", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not null" operator
|
||||
{"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]},
|
||||
# Test boolean array "contains" operator
|
||||
{"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"},
|
||||
# Test boolean "in" operator
|
||||
{
|
||||
"comparison_operator": "in",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": ["true", "false"],
|
||||
},
|
||||
]
|
||||
return [Condition.model_validate(i) for i in conditions]
|
||||
|
||||
|
||||
def _get_condition_test_id(c: Condition):
|
||||
return c.comparison_operator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id)
|
||||
def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
"""Test IfElseNode with boolean conditions using various operators"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
pool.add(["start", "mixed_array"], [True, "false", 1, 0])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "and",
|
||||
"conditions": [condition.model_dump()],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_false_conditions():
|
||||
"""Test IfElseNode with boolean conditions that should evaluate to false"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
pool.add(["start", "bool_array"], [True, False, True])
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean False Test",
|
||||
"type": "if-else",
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
# Test boolean "is" operator (should be false)
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"},
|
||||
# Test boolean "=" operator (should be false)
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"},
|
||||
# Test boolean "not contains" operator (should be false)
|
||||
{
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "bool_array"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data,
|
||||
},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
|
||||
def test_execute_if_else_boolean_cases_structure():
|
||||
"""Test IfElseNode with boolean conditions using the new cases structure"""
|
||||
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
pool.add(["start", "bool_false"], False)
|
||||
|
||||
node_data = {
|
||||
"title": "Boolean Cases Test",
|
||||
"type": "if-else",
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"variable_selector": ["start", "bool_true"],
|
||||
"value": "true",
|
||||
},
|
||||
{
|
||||
"comparison_operator": "is not",
|
||||
"variable_selector": ["start", "bool_false"],
|
||||
"value": "true",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
assert result.outputs["selected_case_id"] == "true"
|
||||
|
||||
@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import (
|
||||
FilterCondition,
|
||||
Limit,
|
||||
ListOperatorNodeData,
|
||||
OrderBy,
|
||||
Order,
|
||||
OrderByConfig,
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
@ -27,7 +28,7 @@ def list_operator_node():
|
||||
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
|
||||
],
|
||||
),
|
||||
"order_by": OrderBy(enabled=False, value="asc"),
|
||||
"order_by": OrderByConfig(enabled=False, value=Order.ASC),
|
||||
"limit": Limit(enabled=False, size=0),
|
||||
"extract_by": ExtractConfig(enabled=False, serial="1"),
|
||||
"title": "Test Title",
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -23,6 +22,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
@ -145,8 +145,8 @@ def real_workflow():
|
||||
workflow.graph = json.dumps(graph_data)
|
||||
workflow.features = json.dumps({"file_upload": {"enabled": False}})
|
||||
workflow.created_by = "test-user-id"
|
||||
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.created_at = naive_utc_now()
|
||||
workflow.updated_at = naive_utc_now()
|
||||
workflow._environment_variables = "{}"
|
||||
workflow._conversation_variables = "{}"
|
||||
|
||||
@ -169,7 +169,7 @@ def real_workflow_run():
|
||||
workflow_run.outputs = json.dumps({"answer": "test answer"})
|
||||
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
|
||||
workflow_run.created_by = "test-user-id"
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.created_at = naive_utc_now()
|
||||
|
||||
return workflow_run
|
||||
|
||||
@ -211,7 +211,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
@ -245,7 +245,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
@ -282,7 +282,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
@ -335,7 +335,7 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
@ -366,7 +366,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||
event.process_data = {"process": "test process"}
|
||||
event.outputs = {"output": "test output"}
|
||||
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
|
||||
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
event.start_at = naive_utc_now()
|
||||
|
||||
# Create a real node execution
|
||||
|
||||
@ -379,7 +379,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the node execution
|
||||
@ -409,7 +409,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
@ -443,7 +443,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||
event.process_data = {"process": "test process"}
|
||||
event.outputs = {"output": "test output"}
|
||||
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
|
||||
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
event.start_at = naive_utc_now()
|
||||
event.error = "Test error message"
|
||||
|
||||
# Create a real node execution
|
||||
@ -457,7 +457,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the node execution
|
||||
|
||||
@ -59,7 +59,7 @@ def mock_response_receiver(monkeypatch) -> mock.Mock:
|
||||
@pytest.fixture
|
||||
def mock_logger(monkeypatch) -> logging.Logger:
|
||||
_logger = mock.MagicMock(spec=logging.Logger)
|
||||
monkeypatch.setattr(ext_request_logging, "_logger", _logger)
|
||||
monkeypatch.setattr(ext_request_logging, "logger", _logger)
|
||||
return _logger
|
||||
|
||||
|
||||
|
||||
@ -24,16 +24,18 @@ from core.variables.segments import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from factories import variable_factory
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
@ -139,6 +141,26 @@ def test_array_number_variable():
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_build_segment_scalar_values():
|
||||
@dataclass
|
||||
class TestCase:
|
||||
value: Any
|
||||
expected: Segment
|
||||
description: str
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
value=True,
|
||||
expected=BooleanSegment(value=True),
|
||||
description="build_segment with boolean should yield BooleanSegment",
|
||||
)
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
seg = build_segment(c.value)
|
||||
assert seg == c.expected, f"Test case {idx} failed: {c.description}"
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors:
|
||||
f"but got: {error_message}"
|
||||
)
|
||||
|
||||
def test_build_segment_boolean_type_note(self):
|
||||
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
|
||||
# Boolean values in Python are subclasses of int, so they get processed as integers
|
||||
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
|
||||
def test_build_segment_boolean_type(self):
|
||||
"""Test that Boolean values are correctly handled as boolean type, not integers."""
|
||||
# Boolean values should now be processed as BooleanSegment, not IntegerSegment
|
||||
# This is because the bool check now comes before the int check in build_segment
|
||||
true_segment = variable_factory.build_segment(True)
|
||||
false_segment = variable_factory.build_segment(False)
|
||||
|
||||
# Verify they are processed as integers, not as errors
|
||||
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
||||
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
||||
assert true_segment.value_type == SegmentType.INTEGER
|
||||
assert false_segment.value_type == SegmentType.INTEGER
|
||||
# Verify they are processed as booleans, not integers
|
||||
assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True"
|
||||
assert false_segment.value is False, (
|
||||
"Test case 2 (boolean_false): Expected False to be processed as boolean False"
|
||||
)
|
||||
assert true_segment.value_type == SegmentType.BOOLEAN
|
||||
assert false_segment.value_type == SegmentType.BOOLEAN
|
||||
|
||||
# Test array of booleans
|
||||
bool_array_segment = variable_factory.build_segment([True, False, True])
|
||||
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN
|
||||
assert bool_array_segment.value == [True, False, True]
|
||||
|
||||
@ -83,7 +83,7 @@ class TestDatasetPermissionService:
|
||||
@pytest.fixture
|
||||
def mock_logging_dependencies(self):
|
||||
"""Mock setup for logging tests."""
|
||||
with patch("services.dataset_service.logging") as mock_logging:
|
||||
with patch("services.dataset_service.logger") as mock_logging:
|
||||
yield {
|
||||
"logging": mock_logging,
|
||||
}
|
||||
|
||||
@ -93,16 +93,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
with (
|
||||
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.datetime") as mock_datetime,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_datetime.datetime.now.return_value = current_time
|
||||
mock_datetime.UTC = datetime.UTC
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_document": mock_get_doc,
|
||||
"db_session": mock_db,
|
||||
"datetime": mock_datetime,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
}
|
||||
|
||||
@ -120,21 +119,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
assert document.enabled == True
|
||||
assert document.disabled_at is None
|
||||
assert document.disabled_by is None
|
||||
assert document.updated_at == current_time.replace(tzinfo=None)
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
|
||||
"""Helper method to verify document was disabled correctly."""
|
||||
assert document.enabled == False
|
||||
assert document.disabled_at == current_time.replace(tzinfo=None)
|
||||
assert document.disabled_at == current_time
|
||||
assert document.disabled_by == user_id
|
||||
assert document.updated_at == current_time.replace(tzinfo=None)
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime):
|
||||
"""Helper method to verify document was archived correctly."""
|
||||
assert document.archived == True
|
||||
assert document.archived_at == current_time.replace(tzinfo=None)
|
||||
assert document.archived_at == current_time
|
||||
assert document.archived_by == user_id
|
||||
assert document.updated_at == current_time.replace(tzinfo=None)
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_unarchived(self, document: Mock):
|
||||
"""Helper method to verify document was unarchived correctly."""
|
||||
@ -430,7 +429,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
# Verify document attributes were updated correctly
|
||||
self._assert_document_unarchived(archived_doc)
|
||||
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None)
|
||||
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"]
|
||||
|
||||
# Verify Redis cache was set (because document is enabled)
|
||||
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
|
||||
@ -495,9 +494,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
# Verify document was unarchived
|
||||
self._assert_document_unarchived(archived_disabled_doc)
|
||||
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace(
|
||||
tzinfo=None
|
||||
)
|
||||
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"]
|
||||
|
||||
# Verify no Redis cache was set (document is disabled)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restful import reqparse
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restful import reqparse
|
||||
from flask_restx import reqparse
|
||||
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
@ -179,7 +179,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
delete_draft_variables_batch(app_id, 0)
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db):
|
||||
"""Test that batch deletion logs progress correctly."""
|
||||
app_id = "test-app-id"
|
||||
|
||||
Reference in New Issue
Block a user