Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice
2025-10-13 13:54:01 +08:00
364 changed files with 7548 additions and 3282 deletions

View File

@ -143,7 +143,7 @@ class TestOAuthCallback:
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
account = MagicMock()
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
token_pair = MagicMock()
token_pair.access_token = "jwt_access_token"
@ -220,11 +220,11 @@ class TestOAuthCallback:
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED.value,
AccountStatus.CLOSED,
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
),
],
@ -296,13 +296,13 @@ class TestOAuthCallback:
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_account = MagicMock()
mock_account.status = AccountStatus.PENDING.value
mock_account.status = AccountStatus.PENDING
mock_generate_account.return_value = mock_account
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
assert mock_account.status == AccountStatus.ACTIVE.value
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@ -352,7 +352,7 @@ class TestOAuthCallback:
# Create account with CLOSED status
closed_account = MagicMock()
closed_account.status = AccountStatus.CLOSED.value
closed_account.status = AccountStatus.CLOSED
closed_account.id = "123"
closed_account.name = "Closed Account"
mock_generate_account.return_value = closed_account

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
result = protected_view()
# Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with pytest.raises(AccountNotInitializedError):
protected_view()
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@ -239,7 +239,7 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@ -271,7 +271,7 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):

View File

@ -0,0 +1,722 @@
import json
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig,
)
from core.rag.models.document import Document
try:
from mysql.connector import Error as MySQLError
except ImportError:
# Fallback for testing environments where mysql-connector-python might not be installed
class MySQLError(Exception):
def __init__(self, errno, msg):
self.errno = errno
self.msg = msg
super().__init__(msg)
class TestAlibabaCloudMySQLVector(unittest.TestCase):
def setUp(self):
self.config = AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
charset="utf8mb4",
)
self.collection_name = "test_collection"
# Sample documents for testing
self.sample_documents = [
Document(
page_content="This is a test document about AI.",
metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
),
Document(
page_content="Another document about machine learning.",
metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
),
]
# Sample embeddings
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_init(self, mock_pool_class):
"""Test AlibabaCloudMySQLVector initialization."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor for vector support check
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"}, # Version check
{"vector_support": True}, # Vector support check
]
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert alibabacloud_mysql_vector.collection_name == self.collection_name
assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
assert alibabacloud_mysql_vector.distance_function == "cosine"
assert alibabacloud_mysql_vector.pool is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_collection(self, mock_redis, mock_pool_class):
"""Test collection creation."""
# Mock Redis operations
mock_redis.lock.return_value.__enter__ = MagicMock()
mock_redis.lock.return_value.__exit__ = MagicMock()
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"}, # Version check
{"vector_support": True}, # Vector support check
]
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
alibabacloud_mysql_vector._create_collection(768)
# Verify SQL execution calls - should include table creation and index creation
assert mock_cursor.execute.called
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once()
def test_config_validation(self):
"""Test configuration validation."""
# Test missing required fields
with pytest.raises(ValueError):
AlibabaCloudMySQLVectorConfig(
host="", # Empty host should raise error
port=3306,
user="test",
password="test",
database="test",
max_connection=5,
)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_success(self, mock_pool_class):
"""Test successful vector support check."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
# Should not raise an exception
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert vector_store is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_failure(self, mock_pool_class):
"""Test vector support check failure."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVector(self.collection_name, self.config)
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_function_error(self, mock_pool_class):
"""Test vector support check with function not found error."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVector(self.collection_name, self.config)
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_documents(self, mock_redis, mock_pool_class):
"""Test creating documents with embeddings."""
# Setup mocks
self._setup_mocks(mock_redis, mock_pool_class)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
result = vector_store.create(self.sample_documents, self.sample_embeddings)
assert len(result) == 2
assert "doc1" in result
assert "doc2" in result
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_add_texts(self, mock_pool_class):
"""Test adding texts to the vector store."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
assert len(result) == 2
mock_cursor.executemany.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_exists(self, mock_pool_class):
"""Test checking if text exists."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"},
{"vector_support": True},
{"id": "doc1"}, # Text exists
]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
exists = vector_store.text_exists("doc1")
assert exists
# Check that the correct SQL was executed (last call after init)
execute_calls = mock_cursor.execute.call_args_list
last_call = execute_calls[-1]
assert "SELECT id FROM" in last_call[0][0]
assert last_call[0][1] == ("doc1",)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_not_exists(self, mock_pool_class):
"""Test checking if text does not exist."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [
{"VERSION()": "8.0.36"},
{"vector_support": True},
None, # Text does not exist
]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
exists = vector_store.text_exists("nonexistent")
assert not exists
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids(self, mock_pool_class):
"""Test getting documents by IDs."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
{"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.get_by_ids(["doc1", "doc2"])
assert len(docs) == 2
assert docs[0].page_content == "Test document 1"
assert docs[1].page_content == "Test document 2"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids_empty_list(self, mock_pool_class):
"""Test getting documents with empty ID list."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.get_by_ids([])
assert len(docs) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids(self, mock_pool_class):
"""Test deleting documents by IDs."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_ids(["doc1", "doc2"])
# Check that delete SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 1
delete_call = delete_calls[0]
assert "DELETE FROM" in delete_call[0][0]
assert delete_call[0][1] == ["doc1", "doc2"]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_empty_list(self, mock_pool_class):
"""Test deleting with empty ID list."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_ids([]) # Should not raise an exception
# Verify no delete SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
"""Test deleting when table doesn't exist."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
# Simulate table doesn't exist error on delete
def execute_side_effect(*args, **kwargs):
if "DELETE" in args[0]:
raise MySQLError(errno=1146, msg="Table doesn't exist")
mock_cursor.execute.side_effect = execute_side_effect
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
# Should not raise an exception
vector_store.delete_by_ids(["doc1"])
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_metadata_field(self, mock_pool_class):
"""Test deleting documents by metadata field."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete_by_metadata_field("document_id", "dataset1")
# Check that the correct SQL was executed
execute_calls = mock_cursor.execute.call_args_list
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 1
delete_call = delete_calls[0]
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
assert delete_call[0][1] == ("$.document_id", "dataset1")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_cosine(self, mock_pool_class):
"""Test vector search with cosine distance."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5)
assert len(docs) == 1
assert docs[0].page_content == "Test document 1"
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
assert docs[0].metadata["distance"] == 0.1
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_euclidean(self, mock_pool_class):
"""Test vector search with euclidean distance."""
config = AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
distance_function="euclidean",
)
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5)
assert len(docs) == 1
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_filter(self, mock_pool_class):
"""Test vector search with document ID filter."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter([])
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
# Verify the SQL contains the WHERE clause for filtering
execute_calls = mock_cursor.execute.call_args_list
search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
assert len(search_calls) > 0
search_call = search_calls[0]
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
"""Test vector search with score threshold."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{
"meta": json.dumps({"doc_id": "doc1", "source": "test"}),
"text": "High similarity document",
"distance": 0.1, # High similarity (score = 0.9)
},
{
"meta": json.dumps({"doc_id": "doc2", "source": "test"}),
"text": "Low similarity document",
"distance": 0.8, # Low similarity (score = 0.2)
},
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
# Only the high similarity document should be returned
assert len(docs) == 1
assert docs[0].page_content == "High similarity document"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
"""Test vector search with invalid top_k."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
query_vector = [0.1, 0.2, 0.3, 0.4]
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k=0)
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text(self, mock_pool_class):
"""Test full-text search."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter(
[
{
"meta": {"doc_id": "doc1", "source": "test"},
"text": "This document contains machine learning content",
"score": 1.5,
}
]
)
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.search_by_full_text("machine learning", top_k=5)
assert len(docs) == 1
assert docs[0].page_content == "This document contains machine learning content"
assert docs[0].metadata["score"] == 1.5
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_with_filter(self, mock_pool_class):
"""Test full-text search with document ID filter."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
mock_cursor.__iter__ = lambda self: iter([])
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
# Verify the SQL contains the AND clause for filtering
execute_calls = mock_cursor.execute.call_args_list
search_calls = [call for call in execute_calls if "MATCH" in str(call)]
assert len(search_calls) > 0
search_call = search_calls[0]
assert "AND JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
"""Test full-text search with invalid top_k."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k=0)
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_collection(self, mock_pool_class):
"""Test deleting the entire collection."""
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
vector_store.delete()
# Check that DROP TABLE SQL was executed
execute_calls = mock_cursor.execute.call_args_list
drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
assert len(drop_calls) == 1
drop_call = drop_calls[0]
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_unsupported_distance_function(self, mock_pool_class):
"""Test that Pydantic validation rejects unsupported distance functions."""
# Test that creating config with unsupported distance function raises ValidationError
with pytest.raises(ValueError) as context:
AlibabaCloudMySQLVectorConfig(
host="localhost",
port=3306,
user="test_user",
password="test_password",
database="test_db",
max_connection=5,
distance_function="manhattan", # Unsupported - not in Literal["cosine", "euclidean"]
)
# The error should be related to validation
assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
def _setup_mocks(self, mock_redis, mock_pool_class):
"""Helper method to setup common mocks."""
# Mock Redis operations
mock_redis.lock.return_value.__enter__ = MagicMock()
mock_redis.lock.return_value.__exit__ = MagicMock()
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.get_connection.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
if __name__ == "__main__":
unittest.main()

View File

@ -11,8 +11,8 @@ def test_default_value():
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
MilvusConfig.model_validate(config)
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig(**valid_config)
config = MilvusConfig.model_validate(valid_config)
assert config.database == "default"

View File

@ -1,10 +1,12 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"
@ -18,7 +20,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
mocked_firecrawl = {
"id": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
assert job_id is not None

View File

@ -1,5 +1,7 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
user_id = "user1"
@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
return text.strip()
def test_notion_page(mocker):
def test_notion_page(mocker: MockerFixture):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
@ -69,7 +71,7 @@ def test_notion_page(mocker):
],
"next_cursor": None,
}
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
page_docs = extractor._load_data_as_documents(page_id, "page")
assert len(page_docs) == 1
@ -77,14 +79,14 @@ def test_notion_page(mocker):
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",
"results": [_generate_page(i) for i in page_title_list],
"next_cursor": None,
}
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
database_docs = extractor._load_data_as_documents(database_id, "database")
assert len(database_docs) == 1
content = _remove_multiple_new_lines(database_docs[0].page_content)

View File

@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository:
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
assert call_args["creator_user_id"] == mock_account.id
# Verify no task tracking occurs (no _pending_saves attribute)

View File

@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
assert call_args["creator_user_id"] == mock_account.id
# Verify execution is cached

View File

@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node-id"
db_model.node_type = NodeType.LLM.value
db_model.node_type = NodeType.LLM
db_model.title = "Test Node"
db_model.inputs = json.dumps({"value": "inputs"})
db_model.process_data = json.dumps({"value": "process_data"})
db_model.outputs = json.dumps({"value": "outputs"})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 1.0
db_model.execution_metadata = "{}"

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
@ -39,7 +40,7 @@ def lb_model_manager():
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())

View File

@ -14,7 +14,13 @@ from core.entities.provider_entities import (
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormOption,
FormType,
ProviderEntity,
)
from models.provider import Provider, ProviderType
@ -306,3 +312,174 @@ class TestProviderConfiguration:
# Assert
assert credentials == {"openai_api_key": "test_key"}
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
"""Test extracting secret variables from credential form schemas"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 2
assert "api_key" in secret_variables
assert "secret_token" in secret_variables
assert "model_name" not in secret_variables
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
"""Test extracting secret variables when no secret input fields exist"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.SELECT,
required=True,
options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
def test_extract_secret_variables_empty_list(self, provider_configuration):
"""Test extracting secret variables from empty credential form schemas"""
# Arrange
credential_form_schemas = []
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
@patch("core.entities.provider_configuration.encrypter")
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
"""Test obfuscating credentials with secret variables"""
# Arrange
credentials = {
"api_key": "sk-1234567890abcdef",
"model_name": "gpt-4",
"secret_token": "secret_value_123",
"temperature": "0.7",
}
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
]
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated["api_key"] == "***cdef"
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
assert obfuscated["secret_token"] == "***_123"
assert obfuscated["temperature"] == "0.7" # Not obfuscated
# Verify encrypter was called for secret fields only
assert mock_encrypter.obfuscated_token.call_count == 2
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
"""Test obfuscating credentials when no secret variables exist"""
# Arrange
credentials = {
"model_name": "gpt-4",
"temperature": "0.7",
"max_tokens": "1000",
}
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="max_tokens",
label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
type=FormType.TEXT_INPUT,
required=True,
),
]
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == credentials # No changes expected
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
"""Test obfuscating empty credentials"""
# Arrange
credentials = {}
credential_form_schemas = []
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == {}

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker):
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker, mock_provider_entity):
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(

View File

@ -147,7 +147,7 @@ class TestRedisChannel:
"""Test deserializing an abort command."""
channel = RedisChannel(MagicMock(), "test:key")
abort_data = {"command_type": CommandType.ABORT.value}
abort_data = {"command_type": CommandType.ABORT}
command = channel._deserialize_command(abort_data)
assert isinstance(command, AbortCommand)
@ -158,7 +158,7 @@ class TestRedisChannel:
channel = RedisChannel(MagicMock(), "test:key")
# For now, only ABORT is supported, but test generic handling
generic_data = {"command_type": CommandType.ABORT.value}
generic_data = {"command_type": CommandType.ABORT}
command = channel._deserialize_command(generic_data)
assert command is not None

View File

@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config():
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config():
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -49,7 +49,7 @@ class TestRedisStopIntegration:
# Verify the command data
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT.value
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
@ -122,7 +122,7 @@ class TestRedisStopIntegration:
# Verify serialized command
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT.value
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "User requested stop"
# Check expire was set
@ -137,9 +137,7 @@ class TestRedisStopIntegration:
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
# Mock command data
abort_command_json = json.dumps(
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
)
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
# Mock pipeline execute to return commands
mock_pipeline.execute.return_value = [

View File

@ -35,7 +35,7 @@ def list_operator_node():
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_data = ListOperatorNodeData.model_validate(config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),

View File

@ -17,7 +17,7 @@ def test_init_question_classifier_node_data():
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config():
},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"

View File

@ -87,7 +87,7 @@ def test_overwrite_string_variable():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"write_mode": WriteMode.OVER_WRITE,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
@ -189,7 +189,7 @@ def test_append_variable_to_array():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"write_mode": WriteMode.APPEND,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
@ -282,7 +282,7 @@ def test_clear_array():
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"write_mode": WriteMode.CLEAR,
"input_variable_selector": [],
},
}

View File

@ -46,7 +46,7 @@ class TestSystemVariableSerialization:
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
@ -59,7 +59,7 @@ class TestSystemVariableSerialization:
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
minimal_var = SystemVariable.model_validate(VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
@ -75,12 +75,12 @@ class TestSystemVariableSerialization:
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
system_var1 = SystemVariable.model_validate(data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
system_var2 = SystemVariable.model_validate(data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
@ -89,17 +89,17 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
system_var3 = SystemVariable.model_validate(data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
system_var4 = SystemVariable.model_validate(VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
@ -110,7 +110,7 @@ class TestSystemVariableSerialization:
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
@ -125,7 +125,7 @@ class TestSystemVariableSerialization:
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
@ -137,7 +137,7 @@ class TestSystemVariableSerialization:
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
deserialized = SystemVariable.model_validate(json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
@ -149,13 +149,13 @@ class TestSystemVariableSerialization:
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
system_var_empty = SystemVariable.model_validate(data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
system_var_single = SystemVariable.model_validate(data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
@ -179,14 +179,14 @@ class TestSystemVariableSerialization:
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
system_var_multiple = SystemVariable.model_validate(data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
@ -197,7 +197,7 @@ class TestSystemVariableSerialization:
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
system_var = SystemVariable.model_validate(data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
@ -205,7 +205,7 @@ class TestSystemVariableSerialization:
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
@ -213,7 +213,7 @@ class TestSystemVariableSerialization:
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
json_deserialized = SystemVariable.model_validate(json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
@ -222,7 +222,7 @@ class TestSystemVariableSerialization:
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
system_var1 = SystemVariable.model_validate(data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
@ -236,7 +236,7 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
system_var2 = SystemVariable.model_validate(data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency

View File

@ -11,7 +11,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_with_tenant(self):
"""Test extracting tenant_id from Account with current_tenant_id."""
# Create a mock Account object
account = Account()
account = Account(name="test", email="test@example.com")
# Mock the current_tenant_id property
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
@ -21,7 +21,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_without_tenant(self):
"""Test extracting tenant_id from Account without current_tenant_id."""
# Create a mock Account object
account = Account()
account = Account(name="test", email="test@example.com")
account._current_tenant = None
tenant_id = extract_tenant_id(account)

View File

@ -59,12 +59,11 @@ def session():
@pytest.fixture
def mock_user():
"""Create a user instance for testing."""
user = Account()
user = Account(name="test", email="test@example.com")
user.id = "test-user-id"
tenant = Tenant()
tenant = Tenant(name="Test Workspace")
tenant.id = "test-tenant"
tenant.name = "Test Workspace"
user._current_tenant = MagicMock()
user._current_tenant.id = "test-tenant"
@ -299,7 +298,7 @@ def test_to_domain_model(repository):
db_model.predecessor_node_id = "test-predecessor-id"
db_model.node_execution_id = "test-node-execution-id"
db_model.node_id = "test-node-id"
db_model.node_type = NodeType.START.value
db_model.node_type = NodeType.START
db_model.title = "Test Node"
db_model.inputs = json.dumps(inputs_dict)
db_model.process_data = json.dumps(process_data_dict)

View File

@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation:
# But would crash when trying to create MetadataArgs
with pytest.raises((ValueError, TypeError)):
MetadataArgs(**args)
MetadataArgs.model_validate(args)
def test_7_end_to_end_validation_layers(self):
"""Test all validation layers work together correctly."""
@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation:
valid_data = {"type": "string", "name": "test_metadata"}
# Should create valid Pydantic object
metadata_args = MetadataArgs(**valid_data)
metadata_args = MetadataArgs.model_validate(valid_data)
assert metadata_args.type == "string"
assert metadata_args.name == "test_metadata"

View File

@ -76,7 +76,7 @@ class TestMetadataNullableBug:
# Step 2: Try to create MetadataArgs with None values
# This should fail at Pydantic validation level
with pytest.raises((ValueError, TypeError)):
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
# Step 3: If we bypass Pydantic (simulating the bug scenario)
# Move this outside the request context to avoid Flask-Login issues

View File

@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id

View File

@ -47,7 +47,8 @@ class TestDraftVariableSaver:
def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session)
mock_user = Account(id=str(uuid.uuid4()))
mock_user = Account(name="test", email="test@example.com")
mock_user.id = str(uuid.uuid4())
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,