mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
refactor: core/tools, agent, callback_handler, encrypter, llm_generator, plugin, inner_api (#34205)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@ -64,18 +64,18 @@ class TestGetActiveAccount:
|
||||
def test_returns_active_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "active"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("user@example.com")
|
||||
|
||||
assert result is mock_account
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_inactive_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "banned"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("banned@example.com")
|
||||
|
||||
@ -83,7 +83,7 @@ class TestGetActiveAccount:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_nonexistent_email(self, mock_db):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
result = _get_active_account("missing@example.com")
|
||||
|
||||
@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "yaml-data"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
|
||||
@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
session.scalar.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
|
||||
@ -114,13 +114,9 @@ class TestOnToolEnd:
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
@ -138,13 +134,9 @@ class TestOnToolEnd:
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
|
||||
@ -38,13 +38,13 @@ class TestObfuscatedToken:
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@ -52,10 +52,10 @@ class TestEncryptToken:
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
mock_query.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@ -181,12 +181,12 @@ class TestSecurity:
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@ -205,13 +205,13 @@ class TestEdgeCases:
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@ -219,13 +219,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@ -242,13 +242,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@ -314,8 +314,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
# Mock __instruction_modify_common call via invoke_llm
|
||||
mock_response = MagicMock()
|
||||
@ -328,12 +328,12 @@ class TestLLMGenerator:
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
last_run = MagicMock()
|
||||
last_run.query = "q"
|
||||
last_run.answer = "a"
|
||||
last_run.error = "e"
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||
mock_scalar.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
@ -483,8 +483,8 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
@ -504,8 +504,8 @@ class TestLLMGenerator:
|
||||
assert "current_val" in user_msg_dict["instruction"]
|
||||
|
||||
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No braces here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@ -516,8 +516,8 @@ class TestLLMGenerator:
|
||||
assert "Could not find a valid JSON object" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@ -556,8 +556,8 @@ class TestLLMGenerator:
|
||||
)
|
||||
|
||||
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@ -566,8 +566,8 @@ class TestLLMGenerator:
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@ -576,8 +576,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||
|
||||
@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation:
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
|
||||
def test_get_app_returns_app(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
app_obj = MagicMock(id="app")
|
||||
query_chain.first.return_value = app_obj
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
query_chain.first.return_value = None
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels():
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
delete_query = mock_db.session.query.return_value.where.return_value
|
||||
delete_query.delete.return_value = None
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
delete_query.delete.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
builtin_provider
|
||||
)
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key": "secret"}
|
||||
cache = Mock()
|
||||
@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
|
||||
)
|
||||
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"token": "old"}
|
||||
encrypter.encrypt.return_value = {"token": "encrypted"}
|
||||
@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch(
|
||||
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
|
||||
) as mock_from_db:
|
||||
@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key_value": "secret"}
|
||||
@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
|
||||
def test_get_api_provider_controller_not_found_raises():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
|
||||
ToolManager.get_api_provider_controller("tenant-1", "missing")
|
||||
|
||||
@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
|
||||
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
|
||||
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
|
||||
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
|
||||
|
||||
@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
|
||||
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
db_session.get.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
|
||||
db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
db_session.get.side_effect = [
|
||||
SimpleNamespace(id="dataset-2", name="Dataset Two"),
|
||||
SimpleNamespace(id="dataset-1", name="Dataset One"),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user