mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
@ -143,7 +143,7 @@ class TestOAuthCallback:
|
||||
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "jwt_access_token"
|
||||
@ -220,11 +220,11 @@ class TestOAuthCallback:
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
|
||||
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
|
||||
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||
(
|
||||
AccountStatus.CLOSED.value,
|
||||
AccountStatus.CLOSED,
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
|
||||
),
|
||||
],
|
||||
@ -296,13 +296,13 @@ class TestOAuthCallback:
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.PENDING.value
|
||||
mock_account.status = AccountStatus.PENDING
|
||||
mock_generate_account.return_value = mock_account
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
assert mock_account.status == AccountStatus.ACTIVE.value
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -352,7 +352,7 @@ class TestOAuthCallback:
|
||||
|
||||
# Create account with CLOSED status
|
||||
closed_account = MagicMock()
|
||||
closed_account.status = AccountStatus.CLOSED.value
|
||||
closed_account.status = AccountStatus.CLOSED
|
||||
closed_account.id = "123"
|
||||
closed_account.name = "Closed Account"
|
||||
mock_generate_account.return_value = closed_account
|
||||
|
||||
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
with pytest.raises(AccountNotInitializedError):
|
||||
protected_view()
|
||||
|
||||
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
|
||||
return "member_added"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = add_member()
|
||||
|
||||
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
add_member()
|
||||
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
|
||||
|
||||
# Test 1: Should reject when source is datasets
|
||||
with app.test_request_context("/?source=datasets"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_document()
|
||||
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
|
||||
|
||||
# Test 2: Should allow when source is not datasets
|
||||
with app.test_request_context("/?source=other"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = upload_document()
|
||||
assert result == "document_uploaded"
|
||||
@ -239,7 +239,7 @@ class TestRateLimiting:
|
||||
return "knowledge_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
@ -271,7 +271,7 @@ class TestRateLimiting:
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
||||
@ -0,0 +1,722 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
AlibabaCloudMySQLVector,
|
||||
AlibabaCloudMySQLVectorConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
try:
|
||||
from mysql.connector import Error as MySQLError
|
||||
except ImportError:
|
||||
# Fallback for testing environments where mysql-connector-python might not be installed
|
||||
class MySQLError(Exception):
|
||||
def __init__(self, errno, msg):
|
||||
self.errno = errno
|
||||
self.msg = msg
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
max_connection=5,
|
||||
charset="utf8mb4",
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
# Sample documents for testing
|
||||
self.sample_documents = [
|
||||
Document(
|
||||
page_content="This is a test document about AI.",
|
||||
metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
|
||||
),
|
||||
Document(
|
||||
page_content="Another document about machine learning.",
|
||||
metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
|
||||
),
|
||||
]
|
||||
|
||||
# Sample embeddings
|
||||
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test AlibabaCloudMySQLVector initialization."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor for vector support check
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"}, # Version check
|
||||
{"vector_support": True}, # Vector support check
|
||||
]
|
||||
|
||||
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert alibabacloud_mysql_vector.collection_name == self.collection_name
|
||||
assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
|
||||
assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
|
||||
assert alibabacloud_mysql_vector.distance_function == "cosine"
|
||||
assert alibabacloud_mysql_vector.pool is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_collection(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_redis.lock.return_value.__enter__ = MagicMock()
|
||||
mock_redis.lock.return_value.__exit__ = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"}, # Version check
|
||||
{"vector_support": True}, # Vector support check
|
||||
]
|
||||
|
||||
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
alibabacloud_mysql_vector._create_collection(768)
|
||||
|
||||
# Verify SQL execution calls - should include table creation and index creation
|
||||
assert mock_cursor.execute.called
|
||||
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
AlibabaCloudMySQLVectorConfig(
|
||||
host="", # Empty host should raise error
|
||||
port=3306,
|
||||
user="test",
|
||||
password="test",
|
||||
database="test",
|
||||
max_connection=5,
|
||||
)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_success(self, mock_pool_class):
|
||||
"""Test successful vector support check."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
# Should not raise an exception
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
assert vector_store is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_failure(self, mock_pool_class):
|
||||
"""Test vector support check failure."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_function_error(self, mock_pool_class):
|
||||
"""Test vector support check with function not found error."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
|
||||
mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_documents(self, mock_redis, mock_pool_class):
|
||||
"""Test creating documents with embeddings."""
|
||||
# Setup mocks
|
||||
self._setup_mocks(mock_redis, mock_pool_class)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
result = vector_store.create(self.sample_documents, self.sample_embeddings)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "doc1" in result
|
||||
assert "doc2" in result
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_add_texts(self, mock_pool_class):
|
||||
"""Test adding texts to the vector store."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
|
||||
|
||||
assert len(result) == 2
|
||||
mock_cursor.executemany.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_text_exists(self, mock_pool_class):
|
||||
"""Test checking if text exists."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"},
|
||||
{"vector_support": True},
|
||||
{"id": "doc1"}, # Text exists
|
||||
]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
exists = vector_store.text_exists("doc1")
|
||||
|
||||
assert exists
|
||||
# Check that the correct SQL was executed (last call after init)
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
last_call = execute_calls[-1]
|
||||
assert "SELECT id FROM" in last_call[0][0]
|
||||
assert last_call[0][1] == ("doc1",)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_text_not_exists(self, mock_pool_class):
|
||||
"""Test checking if text does not exist."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"},
|
||||
{"vector_support": True},
|
||||
None, # Text does not exist
|
||||
]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
exists = vector_store.text_exists("nonexistent")
|
||||
|
||||
assert not exists
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_get_by_ids(self, mock_pool_class):
|
||||
"""Test getting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
|
||||
{"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.get_by_ids(["doc1", "doc2"])
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].page_content == "Test document 1"
|
||||
assert docs[1].page_content == "Test document 2"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_get_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test getting documents with empty ID list."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.get_by_ids([])
|
||||
|
||||
assert len(docs) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids(self, mock_pool_class):
|
||||
"""Test deleting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_ids(["doc1", "doc2"])
|
||||
|
||||
# Check that delete SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 1
|
||||
delete_call = delete_calls[0]
|
||||
assert "DELETE FROM" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ["doc1", "doc2"]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test deleting with empty ID list."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_ids([]) # Should not raise an exception
|
||||
|
||||
# Verify no delete SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
|
||||
"""Test deleting when table doesn't exist."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
# Simulate table doesn't exist error on delete
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
if "DELETE" in args[0]:
|
||||
raise MySQLError(errno=1146, msg="Table doesn't exist")
|
||||
|
||||
mock_cursor.execute.side_effect = execute_side_effect
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
# Should not raise an exception
|
||||
vector_store.delete_by_ids(["doc1"])
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_metadata_field(self, mock_pool_class):
|
||||
"""Test deleting documents by metadata field."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_metadata_field("document_id", "dataset1")
|
||||
|
||||
# Check that the correct SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 1
|
||||
delete_call = delete_calls[0]
|
||||
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ("$.document_id", "dataset1")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_cosine(self, mock_pool_class):
|
||||
"""Test vector search with cosine distance."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "Test document 1"
|
||||
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
|
||||
assert docs[0].metadata["distance"] == 0.1
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_euclidean(self, mock_pool_class):
|
||||
"""Test vector search with euclidean distance."""
|
||||
config = AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
max_connection=5,
|
||||
distance_function="euclidean",
|
||||
)
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_with_filter(self, mock_pool_class):
|
||||
"""Test vector search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter([])
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
|
||||
|
||||
# Verify the SQL contains the WHERE clause for filtering
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
|
||||
assert len(search_calls) > 0
|
||||
search_call = search_calls[0]
|
||||
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
|
||||
"""Test vector search with score threshold."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{
|
||||
"meta": json.dumps({"doc_id": "doc1", "source": "test"}),
|
||||
"text": "High similarity document",
|
||||
"distance": 0.1, # High similarity (score = 0.9)
|
||||
},
|
||||
{
|
||||
"meta": json.dumps({"doc_id": "doc2", "source": "test"}),
|
||||
"text": "Low similarity document",
|
||||
"distance": 0.8, # Low similarity (score = 0.2)
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
|
||||
|
||||
# Only the high similarity document should be returned
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "High similarity document"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
|
||||
"""Test vector search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text(self, mock_pool_class):
|
||||
"""Test full-text search."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{
|
||||
"meta": {"doc_id": "doc1", "source": "test"},
|
||||
"text": "This document contains machine learning content",
|
||||
"score": 1.5,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.search_by_full_text("machine learning", top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "This document contains machine learning content"
|
||||
assert docs[0].metadata["score"] == 1.5
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text_with_filter(self, mock_pool_class):
|
||||
"""Test full-text search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter([])
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
|
||||
|
||||
# Verify the SQL contains the AND clause for filtering
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
search_calls = [call for call in execute_calls if "MATCH" in str(call)]
|
||||
assert len(search_calls) > 0
|
||||
search_call = search_calls[0]
|
||||
assert "AND JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
    "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
    """search_by_full_text must reject non-positive or non-integer top_k values."""
    # Wire up the pool -> connection -> cursor mock chain.
    pool_mock = MagicMock()
    mock_pool_class.return_value = pool_mock

    conn_mock = MagicMock()
    cursor_mock = MagicMock()
    pool_mock.get_connection.return_value = conn_mock
    conn_mock.cursor.return_value = cursor_mock
    # Startup probes: first fetchone() is the server version, second the vector-support flag.
    cursor_mock.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

    store = AlibabaCloudMySQLVector(self.collection_name, self.config)

    # Zero and a non-integer value are both invalid.
    for bad_top_k in (0, "invalid"):
        with pytest.raises(ValueError):
            store.search_by_full_text("test", top_k=bad_top_k)
|
||||
|
||||
@patch(
    "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_collection(self, mock_pool_class):
    """Deleting the collection must issue exactly one DROP TABLE statement."""
    # Wire up the pool -> connection -> cursor mock chain.
    pool_mock = MagicMock()
    mock_pool_class.return_value = pool_mock

    conn_mock = MagicMock()
    cursor_mock = MagicMock()
    pool_mock.get_connection.return_value = conn_mock
    conn_mock.cursor.return_value = cursor_mock
    # Startup probes: first fetchone() is the server version, second the vector-support flag.
    cursor_mock.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

    store = AlibabaCloudMySQLVector(self.collection_name, self.config)
    store.delete()

    # Exactly one DROP TABLE statement, targeting the lower-cased collection name.
    drop_statements = [c for c in cursor_mock.execute.call_args_list if "DROP TABLE" in str(c)]
    assert len(drop_statements) == 1
    assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_statements[0][0][0]
|
||||
|
||||
@patch(
    "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_unsupported_distance_function(self, mock_pool_class):
    """Pydantic validation must reject a distance function outside the supported set."""
    # Building the config with an unknown distance function should fail validation
    # before any connection pool is ever created.
    with pytest.raises(ValueError) as exc_info:
        AlibabaCloudMySQLVectorConfig(
            host="localhost",
            port=3306,
            user="test_user",
            password="test_password",
            database="test_db",
            max_connection=5,
            distance_function="manhattan",  # Unsupported - not in Literal["cosine", "euclidean"]
        )

    # The validation error should mention either the allowed values or the rejected one.
    message = str(exc_info.value)
    assert "Input should be 'cosine' or 'euclidean'" in message or "manhattan" in message
|
||||
|
||||
def _setup_mocks(self, mock_redis, mock_pool_class):
    """Configure the shared Redis and MySQL connection-pool mocks.

    Args:
        mock_redis: patched redis client mock; its lock becomes a no-op
            context manager and its cache reads return nothing.
        mock_pool_class: patched MySQLConnectionPool class mock.

    Returns:
        tuple: ``(mock_conn, mock_cursor)`` so callers can assert on the
        executed SQL directly. Previously this helper returned ``None``;
        the added return value is backward compatible because existing
        callers ignore it.
    """
    # Redis: lock usable as a context manager, cache misses on get/set.
    mock_redis.lock.return_value.__enter__ = MagicMock()
    mock_redis.lock.return_value.__exit__ = MagicMock()
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None

    # Connection pool -> connection -> cursor chain.
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool

    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.get_connection.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    # Startup probes: first fetchone() is the server version, second the
    # vector-support check.
    mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
    return mock_conn, mock_cursor
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (outside the pytest runner).
    unittest.main()
|
||||
@ -11,8 +11,8 @@ def test_default_value():
|
||||
config = valid_config.copy()
|
||||
del config[key]
|
||||
with pytest.raises(ValidationError) as e:
|
||||
MilvusConfig(**config)
|
||||
MilvusConfig.model_validate(config)
|
||||
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
|
||||
|
||||
config = MilvusConfig(**valid_config)
|
||||
config = MilvusConfig.model_validate(valid_config)
|
||||
assert config.database == "default"
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import os
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
@ -18,7 +20,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
mocked_firecrawl = {
|
||||
"id": "test",
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
|
||||
assert job_id is not None
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from unittest import mock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor import notion_extractor
|
||||
|
||||
user_id = "user1"
|
||||
@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
|
||||
return text.strip()
|
||||
|
||||
|
||||
def test_notion_page(mocker):
|
||||
def test_notion_page(mocker: MockerFixture):
|
||||
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
|
||||
mocked_notion_page = {
|
||||
"object": "list",
|
||||
@ -69,7 +71,7 @@ def test_notion_page(mocker):
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
|
||||
mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
|
||||
|
||||
page_docs = extractor._load_data_as_documents(page_id, "page")
|
||||
assert len(page_docs) == 1
|
||||
@ -77,14 +79,14 @@ def test_notion_page(mocker):
|
||||
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
|
||||
|
||||
def test_notion_database(mocker):
|
||||
def test_notion_database(mocker: MockerFixture):
|
||||
page_title_list = ["page1", "page2", "page3"]
|
||||
mocked_notion_database = {
|
||||
"object": "list",
|
||||
"results": [_generate_page(i) for i in page_title_list],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
|
||||
database_docs = extractor._load_data_as_documents(database_id, "database")
|
||||
assert len(database_docs) == 1
|
||||
content = _remove_multiple_new_lines(database_docs[0].page_content)
|
||||
|
||||
@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository:
|
||||
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify no task tracking occurs (no _pending_saves attribute)
|
||||
|
||||
@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify execution is cached
|
||||
|
||||
@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node-id"
|
||||
db_model.node_type = NodeType.LLM.value
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps({"value": "inputs"})
|
||||
db_model.process_data = json.dumps({"value": "process_data"})
|
||||
db_model.outputs = json.dumps({"value": "outputs"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.0
|
||||
db_model.execution_metadata = "{}"
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.model_manager import LBModelManager
|
||||
@ -39,7 +40,7 @@ def lb_model_manager():
|
||||
return lb_model_manager
|
||||
|
||||
|
||||
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
|
||||
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
|
||||
# initialize redis client
|
||||
redis_client.initialize(redis.Redis())
|
||||
|
||||
|
||||
@ -14,7 +14,13 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormOption,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@ -306,3 +312,174 @@ class TestProviderConfiguration:
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
    """Test extracting secret variables from credential form schemas.

    Only fields declared as SECRET_INPUT should be reported as secret,
    regardless of whether they are required.
    """
    # Arrange: two SECRET_INPUT fields (one required, one optional) plus a
    # plain TEXT_INPUT field that must NOT be treated as secret.
    credential_form_schemas = [
        CredentialFormSchema(
            variable="api_key",
            label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
            type=FormType.SECRET_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="secret_token",
            label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
            type=FormType.SECRET_INPUT,
            required=False,
        ),
    ]

    # Act
    secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)

    # Assert: exactly the two SECRET_INPUT variables are returned.
    assert len(secret_variables) == 2
    assert "api_key" in secret_variables
    assert "secret_token" in secret_variables
    assert "model_name" not in secret_variables
|
||||
|
||||
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
    """No variables are reported as secret when the schemas contain no SECRET_INPUT fields."""
    # Only TEXT_INPUT / SELECT fields - nothing here is a secret.
    schemas = [
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="temperature",
            label=I18nObject(en_US="Temperature", zh_Hans="温度"),
            type=FormType.SELECT,
            required=True,
            options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
        ),
    ]

    extracted = provider_configuration.extract_secret_variables(schemas)

    assert not extracted
|
||||
|
||||
def test_extract_secret_variables_empty_list(self, provider_configuration):
    """An empty schema list yields no secret variables."""
    extracted = provider_configuration.extract_secret_variables([])

    assert not extracted
|
||||
|
||||
@patch("core.entities.provider_configuration.encrypter")
|
||||
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
|
||||
"""Test obfuscating credentials with secret variables"""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"api_key": "sk-1234567890abcdef",
|
||||
"model_name": "gpt-4",
|
||||
"secret_token": "secret_value_123",
|
||||
"temperature": "0.7",
|
||||
}
|
||||
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="secret_token",
|
||||
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=False,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated["api_key"] == "***cdef"
|
||||
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
|
||||
assert obfuscated["secret_token"] == "***_123"
|
||||
assert obfuscated["temperature"] == "0.7" # Not obfuscated
|
||||
|
||||
# Verify encrypter was called for secret fields only
|
||||
assert mock_encrypter.obfuscated_token.call_count == 2
|
||||
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
|
||||
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
|
||||
|
||||
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
    """Test obfuscating credentials when no secret variables exist.

    With no SECRET_INPUT fields in the schemas, the credentials mapping
    must be returned unchanged.
    """
    # Arrange: plain configuration values only.
    credentials = {
        "model_name": "gpt-4",
        "temperature": "0.7",
        "max_tokens": "1000",
    }

    # All fields are TEXT_INPUT, so nothing should be obfuscated.
    credential_form_schemas = [
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="temperature",
            label=I18nObject(en_US="Temperature", zh_Hans="温度"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="max_tokens",
            label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
    ]

    # Act
    obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)

    # Assert
    assert obfuscated == credentials  # No changes expected
|
||||
|
||||
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
    """Obfuscating an empty credential mapping returns an empty mapping."""
    result = provider_configuration.obfuscated_credentials({}, [])

    assert result == {}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
||||
@ -147,7 +147,7 @@ class TestRedisChannel:
|
||||
"""Test deserializing an abort command."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
abort_data = {"command_type": CommandType.ABORT.value}
|
||||
abort_data = {"command_type": CommandType.ABORT}
|
||||
command = channel._deserialize_command(abort_data)
|
||||
|
||||
assert isinstance(command, AbortCommand)
|
||||
@ -158,7 +158,7 @@ class TestRedisChannel:
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
# For now, only ABORT is supported, but test generic handling
|
||||
generic_data = {"command_type": CommandType.ABORT.value}
|
||||
generic_data = {"command_type": CommandType.ABORT}
|
||||
command = channel._deserialize_command(generic_data)
|
||||
|
||||
assert command is not None
|
||||
|
||||
@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config():
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config():
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ class TestRedisStopIntegration:
|
||||
# Verify the command data
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "Test stop"
|
||||
|
||||
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
|
||||
@ -122,7 +122,7 @@ class TestRedisStopIntegration:
|
||||
# Verify serialized command
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "User requested stop"
|
||||
|
||||
# Check expire was set
|
||||
@ -137,9 +137,7 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock command data
|
||||
abort_command_json = json.dumps(
|
||||
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
|
||||
)
|
||||
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
|
||||
|
||||
# Mock pipeline execute to return commands
|
||||
mock_pipeline.execute.return_value = [
|
||||
|
||||
@ -35,7 +35,7 @@ def list_operator_node():
|
||||
"extract_by": ExtractConfig(enabled=False, serial="1"),
|
||||
"title": "Test Title",
|
||||
}
|
||||
node_data = ListOperatorNodeData(**config)
|
||||
node_data = ListOperatorNodeData.model_validate(config)
|
||||
node_config = {
|
||||
"id": "test_node_id",
|
||||
"data": node_data.model_dump(),
|
||||
|
||||
@ -17,7 +17,7 @@ def test_init_question_classifier_node_data():
|
||||
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config():
|
||||
},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
|
||||
@ -87,7 +87,7 @@ def test_overwrite_string_variable():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.OVER_WRITE.value,
|
||||
"write_mode": WriteMode.OVER_WRITE,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
}
|
||||
@ -189,7 +189,7 @@ def test_append_variable_to_array():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.APPEND.value,
|
||||
"write_mode": WriteMode.APPEND,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
}
|
||||
@ -282,7 +282,7 @@ def test_clear_array():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.CLEAR.value,
|
||||
"write_mode": WriteMode.CLEAR,
|
||||
"input_variable_selector": [],
|
||||
},
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ class TestSystemVariableSerialization:
|
||||
def test_basic_deserialization(self):
|
||||
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
|
||||
# Test with complete data
|
||||
system_var = SystemVariable(**COMPLETE_VALID_DATA)
|
||||
system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA)
|
||||
|
||||
# Verify all fields are correctly mapped
|
||||
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
|
||||
@ -59,7 +59,7 @@ class TestSystemVariableSerialization:
|
||||
assert system_var.files == []
|
||||
|
||||
# Test with minimal data (only required fields)
|
||||
minimal_var = SystemVariable(**VALID_BASE_DATA)
|
||||
minimal_var = SystemVariable.model_validate(VALID_BASE_DATA)
|
||||
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
|
||||
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
|
||||
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
|
||||
@ -75,12 +75,12 @@ class TestSystemVariableSerialization:
|
||||
|
||||
# Test workflow_run_id only (preferred alias)
|
||||
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||
system_var1 = SystemVariable(**data_run_id)
|
||||
system_var1 = SystemVariable.model_validate(data_run_id)
|
||||
assert system_var1.workflow_execution_id == workflow_id
|
||||
|
||||
# Test workflow_execution_id only (direct field name)
|
||||
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||
system_var2 = SystemVariable(**data_execution_id)
|
||||
system_var2 = SystemVariable.model_validate(data_execution_id)
|
||||
assert system_var2.workflow_execution_id == workflow_id
|
||||
|
||||
# Test both present - workflow_run_id should take precedence
|
||||
@ -89,17 +89,17 @@ class TestSystemVariableSerialization:
|
||||
"workflow_execution_id": "should-be-ignored",
|
||||
"workflow_run_id": workflow_id,
|
||||
}
|
||||
system_var3 = SystemVariable(**data_both)
|
||||
system_var3 = SystemVariable.model_validate(data_both)
|
||||
assert system_var3.workflow_execution_id == workflow_id
|
||||
|
||||
# Test neither present - should be None
|
||||
system_var4 = SystemVariable(**VALID_BASE_DATA)
|
||||
system_var4 = SystemVariable.model_validate(VALID_BASE_DATA)
|
||||
assert system_var4.workflow_execution_id is None
|
||||
|
||||
def test_serialization_round_trip(self):
|
||||
"""Test that serialize → deserialize produces the same result with alias handling."""
|
||||
# Create original SystemVariable
|
||||
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
|
||||
|
||||
# Serialize to dict
|
||||
serialized = original.model_dump(mode="json")
|
||||
@ -110,7 +110,7 @@ class TestSystemVariableSerialization:
|
||||
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||
|
||||
# Deserialize back
|
||||
deserialized = SystemVariable(**serialized)
|
||||
deserialized = SystemVariable.model_validate(serialized)
|
||||
|
||||
# Verify all fields match after round-trip
|
||||
assert deserialized.user_id == original.user_id
|
||||
@ -125,7 +125,7 @@ class TestSystemVariableSerialization:
|
||||
def test_json_round_trip(self):
|
||||
"""Test JSON serialization/deserialization consistency with proper structure."""
|
||||
# Create original SystemVariable
|
||||
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
|
||||
|
||||
# Serialize to JSON string
|
||||
json_str = original.model_dump_json()
|
||||
@ -137,7 +137,7 @@ class TestSystemVariableSerialization:
|
||||
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||
|
||||
# Deserialize from JSON data
|
||||
deserialized = SystemVariable(**json_data)
|
||||
deserialized = SystemVariable.model_validate(json_data)
|
||||
|
||||
# Verify key fields match after JSON round-trip
|
||||
assert deserialized.workflow_execution_id == original.workflow_execution_id
|
||||
@ -149,13 +149,13 @@ class TestSystemVariableSerialization:
|
||||
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
|
||||
# Test with empty files list
|
||||
data_empty = {**VALID_BASE_DATA, "files": []}
|
||||
system_var_empty = SystemVariable(**data_empty)
|
||||
system_var_empty = SystemVariable.model_validate(data_empty)
|
||||
assert system_var_empty.files == []
|
||||
|
||||
# Test with single File object
|
||||
test_file = create_test_file()
|
||||
data_single = {**VALID_BASE_DATA, "files": [test_file]}
|
||||
system_var_single = SystemVariable(**data_single)
|
||||
system_var_single = SystemVariable.model_validate(data_single)
|
||||
assert len(system_var_single.files) == 1
|
||||
assert system_var_single.files[0].filename == "test.txt"
|
||||
assert system_var_single.files[0].tenant_id == "test-tenant-id"
|
||||
@ -179,14 +179,14 @@ class TestSystemVariableSerialization:
|
||||
)
|
||||
|
||||
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
|
||||
system_var_multiple = SystemVariable(**data_multiple)
|
||||
system_var_multiple = SystemVariable.model_validate(data_multiple)
|
||||
assert len(system_var_multiple.files) == 2
|
||||
assert system_var_multiple.files[0].filename == "doc1.txt"
|
||||
assert system_var_multiple.files[1].filename == "image.jpg"
|
||||
|
||||
# Verify files field serialization/deserialization
|
||||
serialized = system_var_multiple.model_dump(mode="json")
|
||||
deserialized = SystemVariable(**serialized)
|
||||
deserialized = SystemVariable.model_validate(serialized)
|
||||
assert len(deserialized.files) == 2
|
||||
assert deserialized.files[0].filename == "doc1.txt"
|
||||
assert deserialized.files[1].filename == "image.jpg"
|
||||
@ -197,7 +197,7 @@ class TestSystemVariableSerialization:
|
||||
|
||||
# Create with workflow_run_id (alias)
|
||||
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||
system_var = SystemVariable(**data_with_alias)
|
||||
system_var = SystemVariable.model_validate(data_with_alias)
|
||||
|
||||
# Serialize and verify alias is used
|
||||
serialized = system_var.model_dump()
|
||||
@ -205,7 +205,7 @@ class TestSystemVariableSerialization:
|
||||
assert "workflow_execution_id" not in serialized
|
||||
|
||||
# Deserialize and verify field mapping
|
||||
deserialized = SystemVariable(**serialized)
|
||||
deserialized = SystemVariable.model_validate(serialized)
|
||||
assert deserialized.workflow_execution_id == workflow_id
|
||||
|
||||
# Test JSON serialization path
|
||||
@ -213,7 +213,7 @@ class TestSystemVariableSerialization:
|
||||
assert json_serialized["workflow_run_id"] == workflow_id
|
||||
assert "workflow_execution_id" not in json_serialized
|
||||
|
||||
json_deserialized = SystemVariable(**json_serialized)
|
||||
json_deserialized = SystemVariable.model_validate(json_serialized)
|
||||
assert json_deserialized.workflow_execution_id == workflow_id
|
||||
|
||||
def test_model_validator_serialization_logic(self):
|
||||
@ -222,7 +222,7 @@ class TestSystemVariableSerialization:
|
||||
|
||||
# Test direct instantiation with workflow_execution_id (should work)
|
||||
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||
system_var1 = SystemVariable(**data1)
|
||||
system_var1 = SystemVariable.model_validate(data1)
|
||||
assert system_var1.workflow_execution_id == workflow_id
|
||||
|
||||
# Test serialization of the above (should use alias)
|
||||
@ -236,7 +236,7 @@ class TestSystemVariableSerialization:
|
||||
"workflow_execution_id": "should-be-removed",
|
||||
"workflow_run_id": workflow_id,
|
||||
}
|
||||
system_var2 = SystemVariable(**data2)
|
||||
system_var2 = SystemVariable.model_validate(data2)
|
||||
assert system_var2.workflow_execution_id == workflow_id
|
||||
|
||||
# Verify serialization consistency
|
||||
|
||||
@ -11,7 +11,7 @@ class TestExtractTenantId:
|
||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account = Account(name="test", email="test@example.com")
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
@ -21,7 +21,7 @@ class TestExtractTenantId:
|
||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account = Account(name="test", email="test@example.com")
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
|
||||
@ -59,12 +59,11 @@ def session():
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a user instance for testing."""
|
||||
user = Account()
|
||||
user = Account(name="test", email="test@example.com")
|
||||
user.id = "test-user-id"
|
||||
|
||||
tenant = Tenant()
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = "test-tenant"
|
||||
tenant.name = "Test Workspace"
|
||||
user._current_tenant = MagicMock()
|
||||
user._current_tenant.id = "test-tenant"
|
||||
|
||||
@ -299,7 +298,7 @@ def test_to_domain_model(repository):
|
||||
db_model.predecessor_node_id = "test-predecessor-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = NodeType.START.value
|
||||
db_model.node_type = NodeType.START
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps(inputs_dict)
|
||||
db_model.process_data = json.dumps(process_data_dict)
|
||||
|
||||
@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation:
|
||||
|
||||
# But would crash when trying to create MetadataArgs
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(**args)
|
||||
MetadataArgs.model_validate(args)
|
||||
|
||||
def test_7_end_to_end_validation_layers(self):
|
||||
"""Test all validation layers work together correctly."""
|
||||
@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation:
|
||||
valid_data = {"type": "string", "name": "test_metadata"}
|
||||
|
||||
# Should create valid Pydantic object
|
||||
metadata_args = MetadataArgs(**valid_data)
|
||||
metadata_args = MetadataArgs.model_validate(valid_data)
|
||||
assert metadata_args.type == "string"
|
||||
assert metadata_args.name == "test_metadata"
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ class TestMetadataNullableBug:
|
||||
# Step 2: Try to create MetadataArgs with None values
|
||||
# This should fail at Pydantic validation level
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
metadata_args = MetadataArgs(**args)
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
# Step 3: If we bypass Pydantic (simulating the bug scenario)
|
||||
# Move this outside the request context to avoid Flask-Login issues
|
||||
|
||||
@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
|
||||
@ -47,7 +47,8 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(id=str(uuid.uuid4()))
|
||||
mock_user = Account(name="test", email="test@example.com")
|
||||
mock_user.id = str(uuid.uuid4())
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
|
||||
Reference in New Issue
Block a user