mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
test: unit test cases for console.datasets module (#32179)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
This commit is contained in:
@ -0,0 +1,817 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceAuth,
|
||||
DatasourceAuthDefaultApi,
|
||||
DatasourceAuthDeleteApi,
|
||||
DatasourceAuthListApi,
|
||||
DatasourceAuthOauthCustomClient,
|
||||
DatasourceAuthUpdateApi,
|
||||
DatasourceHardCodeAuthListApi,
|
||||
DatasourceOAuthCallback,
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-1",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-123",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
|
||||
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {"name": "test"}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_callback_missing_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_invalid_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {} # avatar + name missing
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"reauthorize_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "/oauth-callback" in response.location
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
side_effect=CredentialsValidateFailedError("invalid"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
"enable_oauth_custom_client": False,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
"name": "x" * 101,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
@ -0,0 +1,143 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
DataSourceContentPreviewApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = preview_result
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=payload["inputs"],
|
||||
account=account,
|
||||
datasource_type=payload["datasource_type"],
|
||||
is_published=True,
|
||||
credential_id=payload["credential_id"],
|
||||
)
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
# datasource_type missing
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
CustomizedPipelineTemplateApi,
|
||||
PipelineTemplateDetailApi,
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [{"id": "t1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in&language=en-US"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == templates
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
template = {"id": "tpl-1"}
|
||||
|
||||
service = MagicMock()
|
||||
service.get_pipeline_template_detail.return_value = template
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response = method(api, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert response == {"result": "success"}
|
||||
@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
CreateEmptyRagPipelineDatasetApi,
|
||||
CreateRagPipelineDatasetApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
import_info = {"dataset_id": "ds-1"}
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@ -0,0 +1,324 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
RagPipelineSystemVariableCollectionApi,
|
||||
RagPipelineVariableApi,
|
||||
RagPipelineVariableCollectionApi,
|
||||
RagPipelineVariableResetApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db():
|
||||
db = MagicMock()
|
||||
db.engine = MagicMock()
|
||||
db.session.return_value = MagicMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restx_config(app):
|
||||
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
|
||||
|
||||
|
||||
class TestRagPipelineVariableCollectionApi:
|
||||
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = True
|
||||
|
||||
# IMPORTANT: RESTX expects .variables
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
draft_srv = MagicMock()
|
||||
draft_srv.list_variables_without_values.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineNodeVariableCollectionApi:
|
||||
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_node_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
def test_get_variable_not_found(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(id="p1", tenant_id="t1")
|
||||
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
payload = {"value": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineVariableResetApi:
|
||||
def test_reset_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableResetApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
workflow = MagicMock()
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
srv.reset_variable.return_value = variable
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
|
||||
class TestSystemAndEnvironmentVariablesApi:
|
||||
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineSystemVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_system_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_environment_variables_success(self, app, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
env_var = MagicMock(
|
||||
id="e1",
|
||||
name="ENV",
|
||||
description="d",
|
||||
selector="s",
|
||||
value_type=MagicMock(value="string"),
|
||||
value="x",
|
||||
)
|
||||
|
||||
workflow = MagicMock(environment_variables=[env_var])
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
@ -0,0 +1,329 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
RagPipelineImportApi,
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
RagPipelineImportConfirmApi,
|
||||
)
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
"yaml_content": "content",
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"status": "success"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"status": "failed"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
result.model_dump.return_value = {"status": "pending"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"ok": True}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"ok": False}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
result = MagicMock()
|
||||
result.model_dump.return_value = {"deps": []}
|
||||
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"deps": []}
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": {"yaml": "data"}}
|
||||
@ -0,0 +1,688 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
DraftRagPipelineApi,
|
||||
DraftRagPipelineRunApi,
|
||||
PublishedAllRagPipelineApi,
|
||||
PublishedRagPipelineApi,
|
||||
PublishedRagPipelineRunApi,
|
||||
RagPipelineByIdApi,
|
||||
RagPipelineDatasourceVariableApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
RagPipelineRecommendedPluginApi,
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result == [{"id": "p1"}]
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?q={}"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "llm")
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
node_exec = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
service.get_node_last_run.return_value = node_exec
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
service.set_datasource_variables.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result is not None
|
||||
Reference in New Issue
Block a user