mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge branch 'feat/collaboration' into deploy/dev
This commit is contained in:
@ -143,7 +143,7 @@ class TestOAuthCallback:
|
||||
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "jwt_access_token"
|
||||
@ -220,11 +220,11 @@ class TestOAuthCallback:
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
|
||||
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
|
||||
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||
(
|
||||
AccountStatus.CLOSED.value,
|
||||
AccountStatus.CLOSED,
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
|
||||
),
|
||||
],
|
||||
@ -296,13 +296,13 @@ class TestOAuthCallback:
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.PENDING.value
|
||||
mock_account.status = AccountStatus.PENDING
|
||||
mock_generate_account.return_value = mock_account
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
assert mock_account.status == AccountStatus.ACTIVE.value
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -352,7 +352,7 @@ class TestOAuthCallback:
|
||||
|
||||
# Create account with CLOSED status
|
||||
closed_account = MagicMock()
|
||||
closed_account.status = AccountStatus.CLOSED.value
|
||||
closed_account.status = AccountStatus.CLOSED
|
||||
closed_account.id = "123"
|
||||
closed_account.name = "Closed Account"
|
||||
mock_generate_account.return_value = closed_account
|
||||
|
||||
@ -0,0 +1,722 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
AlibabaCloudMySQLVector,
|
||||
AlibabaCloudMySQLVectorConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
try:
|
||||
from mysql.connector import Error as MySQLError
|
||||
except ImportError:
|
||||
# Fallback for testing environments where mysql-connector-python might not be installed
|
||||
class MySQLError(Exception):
|
||||
def __init__(self, errno, msg):
|
||||
self.errno = errno
|
||||
self.msg = msg
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
max_connection=5,
|
||||
charset="utf8mb4",
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
# Sample documents for testing
|
||||
self.sample_documents = [
|
||||
Document(
|
||||
page_content="This is a test document about AI.",
|
||||
metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
|
||||
),
|
||||
Document(
|
||||
page_content="Another document about machine learning.",
|
||||
metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
|
||||
),
|
||||
]
|
||||
|
||||
# Sample embeddings
|
||||
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test AlibabaCloudMySQLVector initialization."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor for vector support check
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"}, # Version check
|
||||
{"vector_support": True}, # Vector support check
|
||||
]
|
||||
|
||||
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert alibabacloud_mysql_vector.collection_name == self.collection_name
|
||||
assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
|
||||
assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
|
||||
assert alibabacloud_mysql_vector.distance_function == "cosine"
|
||||
assert alibabacloud_mysql_vector.pool is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_collection(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_redis.lock.return_value.__enter__ = MagicMock()
|
||||
mock_redis.lock.return_value.__exit__ = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"}, # Version check
|
||||
{"vector_support": True}, # Vector support check
|
||||
]
|
||||
|
||||
alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
alibabacloud_mysql_vector._create_collection(768)
|
||||
|
||||
# Verify SQL execution calls - should include table creation and index creation
|
||||
assert mock_cursor.execute.called
|
||||
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
AlibabaCloudMySQLVectorConfig(
|
||||
host="", # Empty host should raise error
|
||||
port=3306,
|
||||
user="test",
|
||||
password="test",
|
||||
database="test",
|
||||
max_connection=5,
|
||||
)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_success(self, mock_pool_class):
|
||||
"""Test successful vector support check."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
# Should not raise an exception
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
assert vector_store is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_failure(self, mock_pool_class):
|
||||
"""Test vector support check failure."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_vector_support_check_function_error(self, mock_pool_class):
|
||||
"""Test vector support check with function not found error."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
|
||||
mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_documents(self, mock_redis, mock_pool_class):
|
||||
"""Test creating documents with embeddings."""
|
||||
# Setup mocks
|
||||
self._setup_mocks(mock_redis, mock_pool_class)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
result = vector_store.create(self.sample_documents, self.sample_embeddings)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "doc1" in result
|
||||
assert "doc2" in result
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_add_texts(self, mock_pool_class):
|
||||
"""Test adding texts to the vector store."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
|
||||
|
||||
assert len(result) == 2
|
||||
mock_cursor.executemany.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_text_exists(self, mock_pool_class):
|
||||
"""Test checking if text exists."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"},
|
||||
{"vector_support": True},
|
||||
{"id": "doc1"}, # Text exists
|
||||
]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
exists = vector_store.text_exists("doc1")
|
||||
|
||||
assert exists
|
||||
# Check that the correct SQL was executed (last call after init)
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
last_call = execute_calls[-1]
|
||||
assert "SELECT id FROM" in last_call[0][0]
|
||||
assert last_call[0][1] == ("doc1",)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_text_not_exists(self, mock_pool_class):
|
||||
"""Test checking if text does not exist."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
{"VERSION()": "8.0.36"},
|
||||
{"vector_support": True},
|
||||
None, # Text does not exist
|
||||
]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
exists = vector_store.text_exists("nonexistent")
|
||||
|
||||
assert not exists
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_get_by_ids(self, mock_pool_class):
|
||||
"""Test getting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
|
||||
{"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.get_by_ids(["doc1", "doc2"])
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].page_content == "Test document 1"
|
||||
assert docs[1].page_content == "Test document 2"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_get_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test getting documents with empty ID list."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.get_by_ids([])
|
||||
|
||||
assert len(docs) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids(self, mock_pool_class):
|
||||
"""Test deleting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_ids(["doc1", "doc2"])
|
||||
|
||||
# Check that delete SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 1
|
||||
delete_call = delete_calls[0]
|
||||
assert "DELETE FROM" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ["doc1", "doc2"]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test deleting with empty ID list."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_ids([]) # Should not raise an exception
|
||||
|
||||
# Verify no delete SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
|
||||
"""Test deleting when table doesn't exist."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
# Simulate table doesn't exist error on delete
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
if "DELETE" in args[0]:
|
||||
raise MySQLError(errno=1146, msg="Table doesn't exist")
|
||||
|
||||
mock_cursor.execute.side_effect = execute_side_effect
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
# Should not raise an exception
|
||||
vector_store.delete_by_ids(["doc1"])
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_by_metadata_field(self, mock_pool_class):
|
||||
"""Test deleting documents by metadata field."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete_by_metadata_field("document_id", "dataset1")
|
||||
|
||||
# Check that the correct SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 1
|
||||
delete_call = delete_calls[0]
|
||||
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ("$.document_id", "dataset1")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_cosine(self, mock_pool_class):
|
||||
"""Test vector search with cosine distance."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "Test document 1"
|
||||
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
|
||||
assert docs[0].metadata["distance"] == 0.1
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_euclidean(self, mock_pool_class):
|
||||
"""Test vector search with euclidean distance."""
|
||||
config = AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
max_connection=5,
|
||||
distance_function="euclidean",
|
||||
)
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_with_filter(self, mock_pool_class):
|
||||
"""Test vector search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter([])
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
|
||||
|
||||
# Verify the SQL contains the WHERE clause for filtering
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
|
||||
assert len(search_calls) > 0
|
||||
search_call = search_calls[0]
|
||||
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
|
||||
"""Test vector search with score threshold."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{
|
||||
"meta": json.dumps({"doc_id": "doc1", "source": "test"}),
|
||||
"text": "High similarity document",
|
||||
"distance": 0.1, # High similarity (score = 0.9)
|
||||
},
|
||||
{
|
||||
"meta": json.dumps({"doc_id": "doc2", "source": "test"}),
|
||||
"text": "Low similarity document",
|
||||
"distance": 0.8, # Low similarity (score = 0.2)
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
|
||||
|
||||
# Only the high similarity document should be returned
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "High similarity document"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
|
||||
"""Test vector search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text(self, mock_pool_class):
|
||||
"""Test full-text search."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter(
|
||||
[
|
||||
{
|
||||
"meta": {"doc_id": "doc1", "source": "test"},
|
||||
"text": "This document contains machine learning content",
|
||||
"score": 1.5,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.search_by_full_text("machine learning", top_k=5)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "This document contains machine learning content"
|
||||
assert docs[0].metadata["score"] == 1.5
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text_with_filter(self, mock_pool_class):
|
||||
"""Test full-text search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
mock_cursor.__iter__ = lambda self: iter([])
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
|
||||
|
||||
# Verify the SQL contains the AND clause for filtering
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
search_calls = [call for call in execute_calls if "MATCH" in str(call)]
|
||||
assert len(search_calls) > 0
|
||||
search_call = search_calls[0]
|
||||
assert "AND JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
|
||||
"""Test full-text search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_delete_collection(self, mock_pool_class):
|
||||
"""Test deleting the entire collection."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
vector_store.delete()
|
||||
|
||||
# Check that DROP TABLE SQL was executed
|
||||
execute_calls = mock_cursor.execute.call_args_list
|
||||
drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
|
||||
assert len(drop_calls) == 1
|
||||
drop_call = drop_calls[0]
|
||||
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_unsupported_distance_function(self, mock_pool_class):
|
||||
"""Test that Pydantic validation rejects unsupported distance functions."""
|
||||
# Test that creating config with unsupported distance function raises ValidationError
|
||||
with pytest.raises(ValueError) as context:
|
||||
AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
max_connection=5,
|
||||
distance_function="manhattan", # Unsupported - not in Literal["cosine", "euclidean"]
|
||||
)
|
||||
|
||||
# The error should be related to validation
|
||||
assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
|
||||
|
||||
def _setup_mocks(self, mock_redis, mock_pool_class):
|
||||
"""Helper method to setup common mocks."""
|
||||
# Mock Redis operations
|
||||
mock_redis.lock.return_value.__enter__ = MagicMock()
|
||||
mock_redis.lock.return_value.__exit__ = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -18,7 +18,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
mocked_firecrawl = {
|
||||
"id": "test",
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
|
||||
assert job_id is not None
|
||||
|
||||
@ -69,7 +69,7 @@ def test_notion_page(mocker):
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
|
||||
mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
|
||||
|
||||
page_docs = extractor._load_data_as_documents(page_id, "page")
|
||||
assert len(page_docs) == 1
|
||||
@ -84,7 +84,7 @@ def test_notion_database(mocker):
|
||||
"results": [_generate_page(i) for i in page_title_list],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
|
||||
database_docs = extractor._load_data_as_documents(database_id, "database")
|
||||
assert len(database_docs) == 1
|
||||
content = _remove_multiple_new_lines(database_docs[0].page_content)
|
||||
|
||||
@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository:
|
||||
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify no task tracking occurs (no _pending_saves attribute)
|
||||
|
||||
@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
|
||||
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify execution is cached
|
||||
|
||||
@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node-id"
|
||||
db_model.node_type = NodeType.LLM.value
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps({"value": "inputs"})
|
||||
db_model.process_data = json.dumps({"value": "process_data"})
|
||||
db_model.outputs = json.dumps({"value": "outputs"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.0
|
||||
db_model.execution_metadata = "{}"
|
||||
|
||||
@ -14,7 +14,13 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormOption,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@ -306,3 +312,174 @@ class TestProviderConfiguration:
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
|
||||
"""Test extracting secret variables from credential form schemas"""
|
||||
# Arrange
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="secret_token",
|
||||
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 2
|
||||
assert "api_key" in secret_variables
|
||||
assert "secret_token" in secret_variables
|
||||
assert "model_name" not in secret_variables
|
||||
|
||||
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
|
||||
"""Test extracting secret variables when no secret input fields exist"""
|
||||
# Arrange
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.SELECT,
|
||||
required=True,
|
||||
options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 0
|
||||
|
||||
def test_extract_secret_variables_empty_list(self, provider_configuration):
|
||||
"""Test extracting secret variables from empty credential form schemas"""
|
||||
# Arrange
|
||||
credential_form_schemas = []
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 0
|
||||
|
||||
@patch("core.entities.provider_configuration.encrypter")
|
||||
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
|
||||
"""Test obfuscating credentials with secret variables"""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"api_key": "sk-1234567890abcdef",
|
||||
"model_name": "gpt-4",
|
||||
"secret_token": "secret_value_123",
|
||||
"temperature": "0.7",
|
||||
}
|
||||
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="secret_token",
|
||||
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=False,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated["api_key"] == "***cdef"
|
||||
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
|
||||
assert obfuscated["secret_token"] == "***_123"
|
||||
assert obfuscated["temperature"] == "0.7" # Not obfuscated
|
||||
|
||||
# Verify encrypter was called for secret fields only
|
||||
assert mock_encrypter.obfuscated_token.call_count == 2
|
||||
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
|
||||
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
|
||||
|
||||
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
|
||||
"""Test obfuscating credentials when no secret variables exist"""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"model_name": "gpt-4",
|
||||
"temperature": "0.7",
|
||||
"max_tokens": "1000",
|
||||
}
|
||||
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="max_tokens",
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated == credentials # No changes expected
|
||||
|
||||
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
|
||||
"""Test obfuscating empty credentials"""
|
||||
# Arrange
|
||||
credentials = {}
|
||||
credential_form_schemas = []
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated == {}
|
||||
|
||||
@ -147,7 +147,7 @@ class TestRedisChannel:
|
||||
"""Test deserializing an abort command."""
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
abort_data = {"command_type": CommandType.ABORT.value}
|
||||
abort_data = {"command_type": CommandType.ABORT}
|
||||
command = channel._deserialize_command(abort_data)
|
||||
|
||||
assert isinstance(command, AbortCommand)
|
||||
@ -158,7 +158,7 @@ class TestRedisChannel:
|
||||
channel = RedisChannel(MagicMock(), "test:key")
|
||||
|
||||
# For now, only ABORT is supported, but test generic handling
|
||||
generic_data = {"command_type": CommandType.ABORT.value}
|
||||
generic_data = {"command_type": CommandType.ABORT}
|
||||
command = channel._deserialize_command(generic_data)
|
||||
|
||||
assert command is not None
|
||||
|
||||
@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config():
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config():
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ class TestRedisStopIntegration:
|
||||
# Verify the command data
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "Test stop"
|
||||
|
||||
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
|
||||
@ -122,7 +122,7 @@ class TestRedisStopIntegration:
|
||||
# Verify serialized command
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT.value
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "User requested stop"
|
||||
|
||||
# Check expire was set
|
||||
@ -137,9 +137,7 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock command data
|
||||
abort_command_json = json.dumps(
|
||||
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
|
||||
)
|
||||
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
|
||||
|
||||
# Mock pipeline execute to return commands
|
||||
mock_pipeline.execute.return_value = [
|
||||
|
||||
@ -87,7 +87,7 @@ def test_overwrite_string_variable():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.OVER_WRITE.value,
|
||||
"write_mode": WriteMode.OVER_WRITE,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
}
|
||||
@ -189,7 +189,7 @@ def test_append_variable_to_array():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.APPEND.value,
|
||||
"write_mode": WriteMode.APPEND,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
}
|
||||
@ -282,7 +282,7 @@ def test_clear_array():
|
||||
"data": {
|
||||
"title": "test",
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.CLEAR.value,
|
||||
"write_mode": WriteMode.CLEAR,
|
||||
"input_variable_selector": [],
|
||||
},
|
||||
}
|
||||
|
||||
@ -298,7 +298,7 @@ def test_to_domain_model(repository):
|
||||
db_model.predecessor_node_id = "test-predecessor-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = NodeType.START.value
|
||||
db_model.node_type = NodeType.START
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps(inputs_dict)
|
||||
db_model.process_data = json.dumps(process_data_dict)
|
||||
|
||||
@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
|
||||
Reference in New Issue
Block a user