Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing

This commit is contained in:
yyh
2026-03-09 17:41:13 +08:00
31 changed files with 14364 additions and 5 deletions

View File

@ -24,6 +24,21 @@ updates:
schedule:
interval: "weekly"
open-pull-requests-limit: 2
ignore:
- dependency-name: "@sentry/react"
update-types: ["version-update:semver-major"]
- dependency-name: "tailwindcss"
update-types: ["version-update:semver-major"]
- dependency-name: "echarts"
update-types: ["version-update:semver-major"]
- dependency-name: "uuid"
update-types: ["version-update:semver-major"]
- dependency-name: "react-markdown"
update-types: ["version-update:semver-major"]
- dependency-name: "react-syntax-highlighter"
update-types: ["version-update:semver-major"]
- dependency-name: "react-window"
update-types: ["version-update:semver-major"]
groups:
lexical:
patterns:
@ -33,6 +48,9 @@ updates:
patterns:
- "storybook"
- "@storybook/*"
eslint-group:
patterns:
- "*eslint*"
npm-dependencies:
patterns:
- "*"
@ -41,3 +59,4 @@ updates:
- "@lexical/*"
- "storybook"
- "@storybook/*"
- "*eslint*"

View File

@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
console_ns.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@ -0,0 +1,817 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_auth import (
DatasourceAuth,
DatasourceAuthDefaultApi,
DatasourceAuthDeleteApi,
DatasourceAuthListApi,
DatasourceAuthOauthCustomClient,
DatasourceAuthUpdateApi,
DatasourceHardCodeAuthListApi,
DatasourceOAuthCallback,
DatasourcePluginOAuthAuthorizationUrl,
DatasourceUpdateProviderNameApi,
)
from core.plugin.impl.oauth import OAuthHandler
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/?credential_id=cred-1"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-1",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-123",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
assert "context_id" in response.headers.get("Set-Cookie")
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {"name": "test"}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
def test_callback_missing_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with (
app.test_request_context("/?context_id=bad"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
context = {"user_id": "u", "tenant_id": "t"}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {} # avatar + name missing
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": "cred-1",
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"reauthorize_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
class TestDatasourceAuth:
def test_post_success(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "val"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_invalid_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "bad"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
side_effect=CredentialsValidateFailedError("invalid"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"] == []
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {"credential_id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_missing_credential_id(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {"k": "v"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 201
def test_update_with_credentials_none(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": None}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
response, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
_, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
class TestDatasourceAuthListApi:
def test_list_success(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
def test_auth_list_empty(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {"client_params": {}, "enable_oauth_custom_client": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_empty_payload(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 200
def test_post_disabled_flag(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {
"client_params": {"a": 1},
"enable_oauth_custom_client": False,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
) as setup_mock,
):
_, status = method(api, "notion")
setup_mock.assert_called_once()
assert status == 200
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {"id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"set_default_datasource_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_default_missing_id(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_provider_name",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_update_name_too_long(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {
"credential_id": "id",
"name": "x" * 101,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"name": "Valid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")

View File

@ -0,0 +1,143 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
DataSourceContentPreviewApi,
)
from models import Account
from models.dataset import Pipeline
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDataSourceContentPreviewApi:
def _valid_payload(self):
return {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": "cred-1",
}
def test_post_success(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
node_id = "node-1"
account = MagicMock(spec=Account)
preview_result = {"content": "preview data"}
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = preview_result
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, node_id)
service_instance.run_datasource_node_preview.assert_called_once_with(
pipeline=pipeline,
node_id=node_id,
user_inputs=payload["inputs"],
account=account,
datasource_type=payload["datasource_type"],
is_published=True,
credential_id=payload["credential_id"],
)
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
MagicMock(), # NOT Account
),
):
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
# datasource_type missing
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": None,
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = {"ok": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, "node-1")
service_instance.run_datasource_node_preview.assert_called_once()
assert status == 200
assert response == {"ok": True}

View File

@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
CustomizedPipelineTemplateApi,
PipelineTemplateDetailApi,
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestPipelineTemplateListApi:
def test_get_success(self, app):
api = PipelineTemplateListApi()
method = unwrap(api.get)
templates = [{"id": "t1"}]
with (
app.test_request_context("/?type=built-in&language=en-US"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
return_value=templates,
),
):
response, status = method(api)
assert status == 200
assert response == templates
class TestPipelineTemplateDetailApi:
def test_get_success(self, app):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
template = {"id": "tpl-1"}
service = MagicMock()
service.get_pipeline_template_detail.return_value = template
with (
app.test_request_context("/?type=built-in"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
return_value=service,
),
):
response, status = method(api, "tpl-1")
assert status == 200
assert response == template
class TestCustomizedPipelineTemplateApi:
def test_patch_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
payload = {
"name": "Template",
"description": "Desc",
"icon_info": {"icon": "📘"},
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
) as update_mock,
):
response = method(api, "tpl-1")
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
) as delete_mock,
):
response = method(api, "tpl-1")
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
template = MagicMock()
template.yaml_content = "yaml-data"
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = template
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
response, status = method(api, "tpl-1")
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
with pytest.raises(ValueError):
method(api, "tpl-1")
class TestPublishCustomizedPipelineTemplateApi:
def test_post_success(self, app):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)
payload = {
"name": "Template",
"description": "Desc",
"icon_info": {"icon": "📘"},
}
service = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
return_value=service,
),
):
response = method(api, "pipeline-1")
service.publish_customized_pipeline_template.assert_called_once()
assert response == {"result": "success"}

View File

@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
CreateEmptyRagPipelineDatasetApi,
CreateRagPipelineDatasetApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestCreateRagPipelineDatasetApi:
def _valid_payload(self):
return {"yaml_content": "name: test"}
def test_post_success(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=True)
import_info = {"dataset_id": "ds-1"}
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.return_value = import_info
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
response, status = method(api)
assert status == 201
assert response == import_info
def test_post_forbidden_non_editor(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(Forbidden):
method(api)
def test_post_dataset_name_duplicate(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=True)
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = {}
user = MagicMock(is_dataset_editor=True)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api)
class TestCreateEmptyRagPipelineDatasetApi:
def test_post_success(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
return_value=dataset,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
return_value={"id": "ds-1"},
),
):
response, status = method(api)
assert status == 201
assert response == {"id": "ds-1"}
def test_post_forbidden_non_editor(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(Forbidden):
method(api)

View File

@ -0,0 +1,324 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Response
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
RagPipelineEnvironmentVariableCollectionApi,
RagPipelineNodeVariableCollectionApi,
RagPipelineSystemVariableCollectionApi,
RagPipelineVariableApi,
RagPipelineVariableCollectionApi,
RagPipelineVariableResetApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def fake_db():
db = MagicMock()
db.engine = MagicMock()
db.session.return_value = MagicMock()
return db
@pytest.fixture
def editor_user():
user = MagicMock(spec=Account)
user.has_edit_permission = True
return user
@pytest.fixture
def restx_config(app):
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
class TestRagPipelineVariableCollectionApi:
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
rag_srv = MagicMock()
rag_srv.is_workflow_exist.return_value = True
# IMPORTANT: RESTX expects .variables
var_list = MagicMock()
var_list.variables = []
draft_srv = MagicMock()
draft_srv.list_variables_without_values.return_value = var_list
with (
app.test_request_context("/?page=1&limit=10"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=draft_srv,
),
):
result = method(api, pipeline)
assert result["items"] == []
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock()
rag_srv = MagicMock()
rag_srv.is_workflow_exist.return_value = False
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
):
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_delete_variables_success(self, app, fake_db, editor_user):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.delete)
pipeline = MagicMock(id="p1")
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
):
result = method(api, pipeline)
assert isinstance(result, Response)
assert result.status_code == 204
class TestRagPipelineNodeVariableCollectionApi:
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineNodeVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
var_list = MagicMock()
var_list.variables = []
srv = MagicMock()
srv.list_node_variables.return_value = var_list
with (
app.test_request_context("/"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline, "node1")
assert result["items"] == []
def test_get_node_variables_invalid_node(self, app, editor_user):
api = RagPipelineNodeVariableCollectionApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
):
with pytest.raises(InvalidArgumentError):
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
class TestRagPipelineVariableApi:
def test_get_variable_not_found(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.get)
srv = MagicMock()
srv.get_variable.return_value = None
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
with pytest.raises(NotFoundError):
method(api, MagicMock(), "v1")
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.patch)
pipeline = MagicMock(id="p1", tenant_id="t1")
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
srv = MagicMock()
srv.get_variable.return_value = variable
payload = {"value": "invalid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
with pytest.raises(InvalidArgumentError):
method(api, pipeline, "v1")
def test_delete_variable_success(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.delete)
pipeline = MagicMock(id="p1")
variable = MagicMock(app_id="p1")
srv = MagicMock()
srv.get_variable.return_value = variable
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline, "v1")
assert result.status_code == 204
class TestRagPipelineVariableResetApi:
def test_reset_variable_success(self, app, fake_db, editor_user):
api = RagPipelineVariableResetApi()
method = unwrap(api.put)
pipeline = MagicMock(id="p1")
workflow = MagicMock()
variable = MagicMock(app_id="p1")
srv = MagicMock()
srv.get_variable.return_value = variable
srv.reset_variable.return_value = variable
rag_srv = MagicMock()
rag_srv.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
return_value={"id": "v1"},
),
):
result = method(api, pipeline, "v1")
assert result == {"id": "v1"}
class TestSystemAndEnvironmentVariablesApi:
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineSystemVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
var_list = MagicMock()
var_list.variables = []
srv = MagicMock()
srv.list_system_variables.return_value = var_list
with (
app.test_request_context("/"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline)
assert result["items"] == []
def test_environment_variables_success(self, app, editor_user):
api = RagPipelineEnvironmentVariableCollectionApi()
method = unwrap(api.get)
env_var = MagicMock(
id="e1",
name="ENV",
description="d",
selector="s",
value_type=MagicMock(value="string"),
value="x",
)
workflow = MagicMock(environment_variables=[env_var])
pipeline = MagicMock(id="p1")
rag_srv = MagicMock()
rag_srv.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
):
result = method(api, pipeline)
assert len(result["items"]) == 1

View File

@ -0,0 +1,329 @@
from unittest.mock import MagicMock, patch
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
RagPipelineExportApi,
RagPipelineImportApi,
RagPipelineImportCheckDependenciesApi,
RagPipelineImportConfirmApi,
)
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestRagPipelineImportApi:
def _payload(self, mode="create"):
return {
"mode": mode,
"yaml_content": "content",
"name": "Test",
}
def test_post_success_200(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = "completed"
result.model_dump.return_value = {"status": "success"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
result.model_dump.return_value = {"status": "failed"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.PENDING
result.model_dump.return_value = {"status": "pending"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 202
assert response == {"status": "pending"}
class TestRagPipelineImportConfirmApi:
def test_confirm_success(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
user = MagicMock()
result = MagicMock()
result.status = "completed"
result.model_dump.return_value = {"ok": True}
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
result.model_dump.return_value = {"ok": False}
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
assert status == 400
assert response == {"ok": False}
class TestRagPipelineImportCheckDependenciesApi:
def test_get_success(self, app):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
pipeline = MagicMock(spec=Pipeline)
result = MagicMock()
result.model_dump.return_value = {"deps": []}
service = MagicMock()
service.check_dependencies.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, pipeline)
assert status == 200
assert response == {"deps": []}
class TestRagPipelineExportApi:
def test_get_with_include_secret(self, app):
api = RagPipelineExportApi()
method = unwrap(api.get)
pipeline = MagicMock(spec=Pipeline)
service = MagicMock()
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/?include_secret=true"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, pipeline)
assert status == 200
assert response == {"data": {"yaml": "data"}}

View File

@ -0,0 +1,688 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
DefaultRagPipelineBlockConfigApi,
DraftRagPipelineApi,
DraftRagPipelineRunApi,
PublishedAllRagPipelineApi,
PublishedRagPipelineApi,
PublishedRagPipelineRunApi,
RagPipelineByIdApi,
RagPipelineDatasourceVariableApi,
RagPipelineDraftNodeRunApi,
RagPipelineDraftRunIterationNodeApi,
RagPipelineDraftRunLoopNodeApi,
RagPipelineRecommendedPluginApi,
RagPipelineTaskStopApi,
RagPipelineTransformApi,
RagPipelineWorkflowLastRunApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDraftWorkflowApi:
def test_get_draft_success(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
workflow = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result == workflow
def test_get_draft_not_exist(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_sync_hash_not_match(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
service = MagicMock()
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
with (
app.test_request_context("/", json={"graph": {}, "features": {}}),
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(DraftWorkflowNotSync):
method(api, pipeline)
def test_sync_invalid_text_plain(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
response, status = method(api, pipeline)
assert status == 400
class TestDraftRunNodes:
def test_iteration_node_success(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
result = method(api, pipeline, "node")
assert result == {"ok": True}
def test_iteration_node_conversation_not_exists(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
side_effect=services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node")
def test_loop_node_success(self, app):
api = RagPipelineDraftRunLoopNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
assert method(api, pipeline, "node") == {"ok": True}
class TestPipelineRunApis:
def test_draft_run_success(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
assert method(api, pipeline) == {"ok": True}
def test_draft_run_rate_limit(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context(
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
),
patch.object(
type(console_ns),
"payload",
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
class TestDraftNodeRun:
def test_execution_not_found(self, app):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
service = MagicMock()
service.run_draft_workflow_node.return_value = None
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node")
class TestPublishedPipelineApis:
def test_publish_success(self, app):
api = PublishedRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
workflow = MagicMock(
id="w1",
created_at=datetime.utcnow(),
)
session = MagicMock()
session.merge.return_value = pipeline
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
service = MagicMock()
service.publish_workflow.return_value = workflow
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["result"] == "success"
assert "created_at" in result
class TestMiscApis:
def test_task_stop(self, app):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
) as stop_mock,
):
result = method(api, pipeline, "task-1")
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app):
api = RagPipelineTransformApi()
method = unwrap(api.post)
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, "ds1")
def test_recommended_plugins(self, app):
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
service = MagicMock()
service.get_recommended_plugins.return_value = [{"id": "p1"}]
with (
app.test_request_context("/?type=all"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api)
assert result == [{"id": "p1"}]
class TestPublishedRagPipelineRunApi:
def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
"response_mode": "blocking",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
result = method(api, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
class TestDefaultBlockConfigApi:
def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_default_block_config.return_value = {"k": "v"}
with (
app.test_request_context("/?q={}"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
with app.test_request_context("/?q=bad-json"):
with pytest.raises(ValueError):
method(api, pipeline, "llm")
class TestPublishedAllRagPipelineApi:
def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["items"] == [{"id": "w1"}]
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/?user_id=u2"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, pipeline)
class TestRagPipelineByIdApi:
def test_patch_success(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock(tenant_id="t1")
user = MagicMock(id="u1")
workflow = MagicMock()
service = MagicMock()
service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "w1")
assert result == workflow
def test_patch_no_fields(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={}),
patch.object(type(console_ns), "payload", {}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
result, status = method(api, pipeline, "w1")
assert status == 400
class TestRagPipelineWorkflowLastRunApi:
def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
workflow = MagicMock()
node_exec = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = workflow
service.get_node_last_run.return_value = node_exec
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "node1")
assert result == node_exec
def test_last_run_not_found(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node1")
class TestRagPipelineDatasourceVariableApi:
def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"datasource_type": "db",
"datasource_info": {},
"start_node_id": "n1",
"start_node_title": "Node",
}
service = MagicMock()
service.set_datasource_variables.return_value = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result is not None

View File

@ -0,0 +1,444 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import NotFound
from controllers.console.datasets import data_source
from controllers.console.datasets.data_source import (
DataSourceApi,
DataSourceNotionApi,
DataSourceNotionDatasetSyncApi,
DataSourceNotionDocumentSyncApi,
DataSourceNotionListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def tenant_ctx():
return (MagicMock(id="u1"), "tenant-1")
@pytest.fixture
def patch_tenant(tenant_ctx):
with patch(
"controllers.console.datasets.data_source.current_account_with_tenant",
return_value=tenant_ctx,
):
yield
@pytest.fixture
def mock_engine():
with patch.object(
type(data_source.db),
"engine",
new_callable=PropertyMock,
return_value=MagicMock(),
):
yield
class TestDataSourceApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
binding = MagicMock(
id="b1",
provider="notion",
created_at="now",
disabled=False,
source_info={},
)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: [binding]),
),
):
response, status = method(api)
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: []),
),
):
response, status = method(api)
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "enable")
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "disable")
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "disable")
class TestDataSourceNotionListApi:
def test_get_credential_not_found(self, app, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
response, status = method(api)
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
dataset = MagicMock(data_source_type="notion_import")
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = [document]
response, status = method(api)
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
dataset = MagicMock(data_source_type="other_type")
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session"),
):
with pytest.raises(ValueError):
method(api)
class TestDataSourceNotionApi:
def test_get_preview_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"integration_secret": "t"},
),
patch(
"controllers.console.datasets.data_source.NotionExtractor",
return_value=extractor,
),
):
response, status = method(api, "p1", "page")
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
payload = {
"notion_info_list": [
{
"workspace_id": "w1",
"credential_id": "c1",
"pages": [{"page_id": "p1", "type": "page"}],
}
],
"process_rule": {"rules": {}},
"doc_form": "text_model",
"doc_language": "English",
}
with (
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
patch(
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
),
patch(
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
),
):
response, status = method(api)
assert status == 200
class TestDataSourceNotionDatasetSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
return_value=[MagicMock(id="d1")],
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1")
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
class TestDataSourceNotionDocumentSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,399 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.external import (
BedrockRetrievalApi,
ExternalApiTemplateApi,
ExternalApiTemplateListApi,
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_external_dataset")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
user.is_dataset_editor = True
user.has_edit_permission = True
user.is_dataset_operator = True
return user
@pytest.fixture(autouse=True)
def mock_auth(mocker, current_user):
mocker.patch(
"controllers.console.datasets.external.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
)
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
api_item = MagicMock()
api_item.to_dict.return_value = {"id": "1"}
with (
app.test_request_context("/?page=1&limit=20"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
):
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
patch.object(
ExternalDatasetService,
"create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_api",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
api = ExternalApiTemplateApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "api-id")
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
dataset = MagicMock()
dataset.embedding_available = False
dataset.built_in_field_enabled = False
dataset.is_published = False
dataset.enable_api = False
dataset.enable_qa = False
dataset.enable_vector_store = False
dataset.vector_store_setting = None
dataset.is_multimodal = False
dataset.retrieval_model_dict = {}
dataset.tags = []
dataset.external_knowledge_info = None
dataset.external_retrieval_model = None
dataset.doc_metadata = []
dataset.icon_info = None
dataset.summary_index_setting = MagicMock()
dataset.summary_index_setting.enable = False
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetService,
"create_external_dataset",
return_value=dataset,
),
):
_, status = method(api)
assert status == 201
def test_create_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {"query": "hello"}
dataset = MagicMock()
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(DatasetService, "get_dataset", return_value=dataset),
patch.object(DatasetService, "check_dataset_permission"),
patch.object(
HitTestingService,
"external_retrieve",
return_value={"ok": True},
),
):
resp = method(api, "dataset-id")
assert resp["ok"] is True
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "hello",
"knowledge_id": "kid",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetTestService,
"knowledge_retrieval",
return_value={"ok": True},
),
):
resp, status = method()
assert status == 200
assert resp["ok"] is True
class TestExternalApiTemplateListApiAdvanced:
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_get_with_pagination(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
with (
app.test_request_context("/?page=1&limit=20"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
class TestExternalDatasetCreateApiAdvanced:
def test_create_forbidden(self, app, mock_auth, current_user):
"""Test creating external dataset without permission"""
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
current_user.is_dataset_editor = False
payload = {
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "ek-1",
"name": "new_dataset",
"description": "A dataset",
}
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApiAdvanced:
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
"""Test hit testing on non-existent dataset"""
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test query",
"external_retrieval_model": None,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
dataset = MagicMock()
payload = {
"query": "test query",
"external_retrieval_model": {"type": "bm25"},
"metadata_filtering_conditions": {"status": "active"},
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
patch(
"controllers.console.datasets.external.HitTestingService.external_retrieve",
return_value={"results": []},
),
):
resp = method(api, "ds-1")
assert resp["results"] == []
class TestBedrockRetrievalApiAdvanced:
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "test",
"knowledge_id": "k-1",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
side_effect=ValueError("Invalid settings"),
),
):
with pytest.raises(ValueError):
method()

View File

@ -0,0 +1,160 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.hit_testing import HitTestingApi
from controllers.console.datasets.hit_testing_base import HitTestingPayload
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_hit_testing")
app.config["TESTING"] = True
return app
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass all decorators on the API method."""
mocker.patch(
"controllers.console.datasets.hit_testing.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.login_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.account_initialization_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
return_value=lambda *_: (lambda f: f),
)
class TestHitTestingApi:
def test_hit_testing_success(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "what is vector search",
"top_k": 3,
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
),
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": "what is vector search", "records": []},
),
):
result = method(api, dataset_id)
assert "query" in result
assert "records" in result
assert result["records"] == []
def test_hit_testing_dataset_not_found(self, app, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
side_effect=NotFound("Dataset not found"),
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
side_effect=ValueError("Invalid parameters"),
),
):
with pytest.raises(ValueError, match="Invalid parameters"):
method(api, dataset_id)

View File

@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import (
DatasetsHitTestingBase,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@pytest.fixture
def account():
acc = MagicMock(spec=Account)
return acc
@pytest.fixture(autouse=True)
def patch_current_user(mocker, account):
"""Patch current_user to a valid Account."""
mocker.patch(
"controllers.console.datasets.hit_testing_base.current_user",
account,
)
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
class TestGetAndValidateDataset:
def test_success(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
):
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
assert result == dataset
def test_dataset_not_found(self):
with patch.object(
DatasetService,
"get_dataset",
return_value=None,
):
with pytest.raises(NotFound, match="Dataset not found"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
def test_permission_denied(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
side_effect=services.errors.account.NoPermissionError("no access"),
),
):
with pytest.raises(Forbidden, match="no access"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
class TestHitTestingArgsCheck:
def test_args_check_called(self):
args = {"query": "test"}
with patch.object(
HitTestingService,
"hit_testing_args_check",
) as check_mock:
DatasetsHitTestingBase.hit_testing_args_check(args)
check_mock.assert_called_once_with(args)
class TestParseArgs:
def test_parse_args_success(self):
payload = {"query": "hello"}
result = DatasetsHitTestingBase.parse_args(payload)
assert result["query"] == "hello"
def test_parse_args_invalid(self):
payload = {"query": "x" * 300}
with pytest.raises(ValueError):
DatasetsHitTestingBase.parse_args(payload)
class TestPerformHitTesting:
def test_success(self, dataset):
response = {
"query": "hello",
"records": [],
}
with patch.object(
HitTestingService,
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == "hello"
assert result["records"] == []
def test_index_not_initialized(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=services.errors.index.IndexNotInitializedError(),
):
with pytest.raises(DatasetNotInitializedError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_provider_token_not_init(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ProviderTokenNotInitError("token missing"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_quota_exceeded(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=QuotaExceededError(),
):
with pytest.raises(ProviderQuotaExceededError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_model_not_supported(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ModelCurrentlyNotSupportError(),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_llm_bad_request(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=LLMBadRequestError("bad request"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_invoke_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=InvokeError("invoke failed"),
):
with pytest.raises(CompletionRequestError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_value_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ValueError("bad args"),
):
with pytest.raises(ValueError, match="bad args"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_unexpected_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=Exception("boom"),
):
with pytest.raises(InternalServerError, match="boom"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})

View File

@ -0,0 +1,362 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.metadata import (
DatasetMetadataApi,
DatasetMetadataBuiltInFieldActionApi,
DatasetMetadataBuiltInFieldApi,
DatasetMetadataCreateApi,
DocumentMetadataEditApi,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_dataset_metadata")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
return user
@pytest.fixture
def dataset():
ds = MagicMock()
ds.id = "dataset-1"
return ds
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def metadata_id():
return uuid.uuid4()
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass setup/login/license decorators."""
mocker.patch(
"controllers.console.datasets.metadata.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.account_initialization_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.enterprise_license_required",
lambda f: f,
)
class TestDatasetMetadataCreateApi:
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
payload = {"name": "author"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"create_metadata",
return_value={"id": "m1", "name": "author"},
),
):
result, status = method(api, dataset_id)
assert status == 201
assert result["name"] == "author"
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
valid_payload = {
"type": "string",
"name": "author",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=valid_payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
class TestDatasetMetadataGetApi:
def test_get_metadata_success(self, app, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
MetadataService,
"get_dataset_metadatas",
return_value=[{"id": "m1"}],
),
):
result, status = method(api, dataset_id)
assert status == 200
assert isinstance(result, list)
def test_get_metadata_dataset_not_found(self, app, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, dataset_id)
class TestDatasetMetadataApi:
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.patch)
payload = {"name": "updated-name"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"update_metadata_name",
return_value={"id": "m1", "name": "updated-name"},
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 200
assert result["name"] == "updated-name"
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"delete_metadata",
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 204
assert result["result"] == "success"
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
MetadataService,
"get_built_in_fields",
return_value=["title", "source"],
),
):
result, status = method(api)
assert status == 200
assert result["fields"] == ["title", "source"]
class TestDatasetMetadataBuiltInFieldActionApi:
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataBuiltInFieldActionApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"enable_built_in_field",
),
):
result, status = method(api, dataset_id, "enable")
assert status == 200
assert result["result"] == "success"
class TestDocumentMetadataEditApi:
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
api = DocumentMetadataEditApi()
method = unwrap(api.post)
payload = {"operation": "add", "metadata": {}}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataOperationData,
"model_validate",
return_value=MagicMock(),
),
patch.object(
MetadataService,
"update_documents_metadata",
),
):
result, status = method(api, dataset_id)
assert status == 200
assert result["result"] == "success"

View File

@ -0,0 +1,233 @@
from unittest.mock import Mock, PropertyMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.datasets.website import (
WebsiteCrawlApi,
WebsiteCrawlStatusApi,
)
from services.website_service import (
WebsiteCrawlApiRequest,
WebsiteCrawlStatusApiRequest,
WebsiteService,
)
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_website_crawl")
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def bypass_auth_and_setup(mocker):
"""Bypass setup/login/account decorators."""
mocker.patch(
"controllers.console.datasets.website.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.account_initialization_required",
lambda f: f,
)
class TestWebsiteCrawlApi:
def test_crawl_success(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {"depth": 1},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
return_value={"job_id": "job-1"},
)
result, status = method(api)
assert status == 200
assert result["job_id"] == "job-1"
def test_crawl_invalid_payload(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "bad-url",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
side_effect=ValueError("invalid payload"),
)
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
method(api)
def test_crawl_service_error(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
side_effect=Exception("crawl failed"),
)
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
method(api)
class TestWebsiteCrawlStatusApi:
def test_get_status_success(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
return_value={"status": "completed"},
)
result, status = method(api, job_id)
assert status == 200
assert result["status"] == "completed"
def test_get_status_invalid_provider(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
side_effect=ValueError("invalid provider"),
)
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
method(api, job_id)
def test_get_status_service_error(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
side_effect=Exception("status lookup failed"),
)
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
method(api, job_id)

View File

@ -0,0 +1,117 @@
from unittest.mock import Mock
import pytest
from controllers.console.datasets.error import PipelineNotFoundError
from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline
class TestGetRagPipeline:
def test_missing_pipeline_id(self):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
with pytest.raises(ValueError, match="missing pipeline_id"):
dummy_view()
def test_pipeline_not_found(self, mocker):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = None
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
with pytest.raises(PipelineNotFoundError):
dummy_view(pipeline_id="pipeline-1")
def test_pipeline_found_and_injected(self, mocker):
pipeline = Mock(spec=Pipeline)
pipeline.id = "pipeline-1"
pipeline.tenant_id = "tenant-1"
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result is pipeline
def test_pipeline_id_removed_from_kwargs(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
assert "pipeline_id" not in kwargs
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result == "ok"
def test_pipeline_id_cast_to_string(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
def where_side_effect(*args, **kwargs):
assert args[0].right.value == "123"
return Mock(first=lambda: pipeline)
mock_query = Mock()
mock_query.where.side_effect = where_side_effect
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id=123)
assert result is pipeline

View File

@ -0,0 +1,341 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
)
from controllers.console.error import AccountInFreezeError
from controllers.console.workspace.account import (
AccountAvatarApi,
AccountDeleteApi,
AccountDeleteVerifyApi,
AccountInitApi,
AccountIntegrateApi,
AccountInterfaceLanguageApi,
AccountInterfaceThemeApi,
AccountNameApi,
AccountPasswordApi,
AccountProfileApi,
AccountTimezoneApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
CheckEmailUnique,
)
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidAccountDeletionCodeError,
)
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAccountInitApi:
def test_init_success(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="inactive")
payload = {
"interface_language": "en-US",
"timezone": "UTC",
"invitation_code": "code123",
}
with (
app.test_request_context("/account/init", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="active")
with (
app.test_request_context("/account/init"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
):
with pytest.raises(AccountAlreadyInitedError):
method(api)
class TestAccountProfileApi:
def test_get_profile_success(self, app):
api = AccountProfileApi()
method = unwrap(api.get)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/account/profile"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountUpdateApis:
@pytest.mark.parametrize(
("api_cls", "payload"),
[
(AccountNameApi, {"name": "test"}),
(AccountAvatarApi, {"avatar": "img.png"}),
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
(AccountTimezoneApi, {"timezone": "UTC"}),
],
)
def test_update_success(self, app, api_cls, payload):
api = api_cls()
method = unwrap(api.post)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountPasswordApi:
def test_password_success(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "old",
"new_password": "new123",
"repeat_new_password": "new123",
}
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
):
result = method(api)
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "bad",
"new_password": "new123",
"repeat_new_password": "new123",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.update_account_password",
side_effect=ServicePwdError(),
),
):
with pytest.raises(CurrentPasswordIncorrectError):
method(api)
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
api = AccountIntegrateApi()
method = unwrap(api.get)
account = MagicMock(id="acc1")
with (
app.test_request_context("/"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
):
scalars_mock.return_value.all.return_value = []
result = method(api)
assert "data" in result
assert len(result["data"]) == 2
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
return_value=("token", "1234"),
),
patch(
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
return_value=None,
),
):
result = method(api)
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
api = AccountDeleteApi()
method = unwrap(api.post)
payload = {"token": "t", "code": "x"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
return_value=False,
),
):
with pytest.raises(InvalidAccountDeletionCodeError):
method(api)
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
payload = {"email": "a@test.com", "code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.account.AccountService.get_change_email_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
api = ChangeEmailResetApi()
method = unwrap(api.post)
payload = {"new_email": "x@test.com", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
):
with pytest.raises(EmailAlreadyInUseError):
method(api)
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "ok@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
):
result = method(api)
assert result["result"] == "success"
def test_email_in_freeze(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "x@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
):
with pytest.raises(AccountInFreezeError):
method(api)

View File

@ -0,0 +1,139 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
AgentProviderApi,
AgentProviderListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAgentProviderListApi:
def test_get_success(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
providers = [{"name": "openai"}, {"name": "anthropic"}]
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=providers,
),
):
result = method(api)
assert result == providers
def test_get_empty_list(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=[],
),
):
result = method(api)
assert result == []
def test_get_account_not_found(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api)
class TestAgentProviderApi:
def test_get_success(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "openai"
provider_data = {"name": "openai", "models": ["gpt-4"]}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=provider_data,
),
):
result = method(api, provider_name)
assert result == provider_data
def test_get_provider_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "unknown"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=None,
),
):
result = method(api, provider_name)
assert result is None
def test_get_account_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api, "openai")

View File

@ -0,0 +1,305 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.workspace.endpoint import (
EndpointCreateApi,
EndpointDeleteApi,
EndpointDisableApi,
EndpointEnableApi,
EndpointListApi,
EndpointListForSinglePluginApi,
EndpointUpdateApi,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user_and_tenant():
return MagicMock(id="u1"), "t1"
@pytest.fixture
def patch_current_account(user_and_tenant):
with patch(
"controllers.console.workspace.endpoint.current_account_with_tenant",
return_value=user_and_tenant,
):
yield
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCreateApi:
def test_create_success(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {"a": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_create_permission_denied(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
side_effect=PluginPermissionDeniedError("denied"),
),
):
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "p1",
"name": "",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
):
result = method(api)
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=0&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
patch(
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
return_value=[{"id": "e1"}],
),
):
result = method(api)
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDeleteApi:
def test_delete_success(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_delete_invalid_payload(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointUpdateApi:
def test_update_success(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "new-name",
"settings": {"x": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_update_validation_error(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "n",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_enable_invalid_payload(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_disable_invalid_payload(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)

View File

@ -0,0 +1,607 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
OwnerTransfer,
OwnerTransferCheckApi,
SendOwnerTransferEmailApi,
)
from services.errors.account import AccountAlreadyInTenantError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMemberListApi:
def test_get_success(self, app):
api = MemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "m1"
member.name = "Member"
member.email = "member@test.com"
member.avatar = "avatar.png"
member.role = "admin"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = MemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
"language": "en-US",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = False
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=AccountAlreadyInTenantError(),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
payload = {
"emails": ["a@test.com"],
"role": "owner",
}
with app.test_request_context("/", json=payload):
result, status = method(api)
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=Exception("boom"),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, _ = method(api)
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "invalid-role"}
with app.test_request_context("/", json=payload):
result, status = method(api, "id")
assert status == 400
def test_update_member_not_found(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "op1"
member.name = "Operator"
member.email = "operator@test.com"
member.avatar = "avatar.png"
member.role = "operator"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock(name="ws")
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
),
):
result = method(api)
assert result["result"] == "success"
def test_send_ip_limit(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
):
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/", json={}),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
):
with pytest.raises(NotOwnerError):
method(api)
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=True,
),
):
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "b@test.com", "code": "x"},
),
):
with pytest.raises(InvalidEmailError):
method(api)
class TestOwnerTransferApi:
def test_transfer_self(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
):
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")

View File

@ -0,0 +1,388 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.model_providers import (
ModelProviderCredentialApi,
ModelProviderCredentialSwitchApi,
ModelProviderIconApi,
ModelProviderListApi,
ModelProviderPaymentCheckoutUrlApi,
ModelProviderValidateApi,
PreferredProviderTypeUpdateApi,
)
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
INVALID_UUID = "123"
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestModelProviderListApi:
def test_get_success(self, app):
api = ModelProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?model_type=llm"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
return_value=[{"name": "openai"}],
),
):
result = method(api)
assert "data" in result
class TestModelProviderCredentialApi:
def test_get_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={VALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
return_value={"key": "value"},
),
):
result = method(api, provider="openai")
assert "credentials" in result
def test_get_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}, "name": "test"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
side_effect=CredentialsValidateFailedError("bad"),
),
):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 204
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": INVALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderValidateApi:
def test_validate_success(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_validate_failure(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
side_effect=CredentialsValidateFailedError("bad"),
),
):
result = method(api, provider="openai")
assert result["result"] == "error"
class TestModelProviderIconApi:
def test_icon_success(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(b"123", "image/png"),
),
):
response = api.get("t1", "openai", "logo", "en")
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(None, None),
),
):
with pytest.raises(ValueError):
api.get("t1", "openai", "logo", "en")
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "custom"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_invalid_enum(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "invalid"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
return_value=None,
),
patch(
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
return_value={"url": "x"},
),
):
result = method(api, provider="anthropic")
assert "url" in result
def test_invalid_provider(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
side_effect=Forbidden(),
),
):
with pytest.raises(Forbidden):
method(api, provider="anthropic")

View File

@ -0,0 +1,447 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.workspace.models import (
DefaultModelApi,
ModelProviderAvailableModelApi,
ModelProviderModelApi,
ModelProviderModelCredentialApi,
ModelProviderModelCredentialSwitchApi,
ModelProviderModelDisableApi,
ModelProviderModelEnableApi,
ModelProviderModelParameterRuleApi,
ModelProviderModelValidateApi,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDefaultModelApi:
def test_get_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={"model_type": ModelType.LLM.value},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
result = method(api)
assert "data" in result
def test_post_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.post)
payload = {
"model_settings": [
{
"model_type": ModelType.LLM.value,
"provider": "openai",
"model": "gpt-4",
}
]
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api)
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_default_model_of_model_type.return_value = None
result = method(api)
assert "data" in result
class TestModelProviderModelApi:
def test_get_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
def test_post_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"load_balancing": {
"configs": [{"weight": 1}],
"enabled": True,
},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
):
result, status = method(api, "openai")
assert status == 200
def test_delete_model_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.delete)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
def test_get_models_returns_empty(self, app):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
class TestModelProviderModelCredentialApi:
def test_get_credentials_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={
"model": "gpt-4",
"model_type": ModelType.LLM.value,
},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
):
provider_service.return_value.get_model_credential.return_value = {
"credentials": {},
"current_credential_id": None,
"current_credential_name": None,
}
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert "credentials" in result
def test_create_credential_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 201
def test_get_empty_credentials(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
):
service.return_value.get_model_credential.return_value = None
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert result["credentials"] == {}
def test_delete_success(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
payload = {
"model": "gpt",
"model_type": ModelType.LLM.value,
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
class TestModelProviderModelCredentialSwitchApi:
def test_switch_success(self, app: Flask):
api = ModelProviderModelCredentialSwitchApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credential_id": "abc",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelEnableDisableApis:
def test_enable_model(self, app: Flask):
api = ModelProviderModelEnableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
def test_disable_model(self, app: Flask):
api = ModelProviderModelDisableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelProviderModelValidateApi:
def test_validate_success(self, app: Flask):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
def test_validate_failure(self, app: Flask, model_name: str):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": model_name,
"model_type": ModelType.LLM.value,
"credentials": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
result = method(api, "openai")
assert result["result"] == "error"
class TestParameterAndAvailableModels:
def test_parameter_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt-4"}),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert "data" in result
def test_available_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert "data" in result
def test_empty_rules(self, app):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt"}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert result["data"] == []
def test_no_models(self, app):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert result["data"] == []

File diff suppressed because it is too large Load Diff

View File

@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from controllers.console.workspace.tool_providers import (
ToolApiListApi,
ToolApiProviderAddApi,
ToolApiProviderDeleteApi,
ToolApiProviderGetApi,
ToolApiProviderGetRemoteSchemaApi,
ToolApiProviderListToolsApi,
ToolApiProviderUpdateApi,
ToolBuiltinListApi,
ToolBuiltinProviderAddApi,
ToolBuiltinProviderCredentialsSchemaApi,
ToolBuiltinProviderDeleteApi,
ToolBuiltinProviderGetCredentialInfoApi,
ToolBuiltinProviderGetCredentialsApi,
ToolBuiltinProviderGetOauthClientSchemaApi,
ToolBuiltinProviderIconApi,
ToolBuiltinProviderInfoApi,
ToolBuiltinProviderListToolsApi,
ToolBuiltinProviderSetDefaultApi,
ToolBuiltinProviderUpdateApi,
ToolLabelsApi,
ToolOAuthCallback,
ToolOAuthCustomClient,
ToolPluginOAuthApi,
ToolProviderListApi,
ToolProviderMCPApi,
ToolWorkflowListApi,
ToolWorkflowProviderCreateApi,
ToolWorkflowProviderDeleteApi,
ToolWorkflowProviderGetApi,
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def _mock_cache():
return
@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]
class TestUtils:
def test_is_valid_url(self):
assert is_valid_url("https://example.com")
assert is_valid_url("http://example.com")
assert not is_valid_url("")
assert not is_valid_url("ftp://example.com")
assert not is_valid_url("not-a-url")
assert not is_valid_url(None)
class TestToolProviderListApi:
def test_get_success(self, app):
api = ToolProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
),
patch(
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
return_value=["p1"],
),
):
assert method(api) == ["p1"]
class TestBuiltinProviderApis:
def test_list_tools(self, app):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
return_value=[{"a": 1}],
),
):
assert method(api, "provider") == [{"a": 1}]
def test_info(self, app):
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
return_value={"x": 1},
),
):
assert method(api, "provider") == {"x": 1}
def test_delete(self, app):
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_id": "cid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
return_value={"result": "success"},
),
):
assert method(api, "provider")["result"] == "success"
def test_add_invalid_type(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api, "provider")
def test_add_success(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
return_value={"id": 1},
),
):
assert method(api, "provider")["id"] == 1
def test_update(self, app):
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credentials(self, app):
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
return_value={"k": "v"},
),
):
assert method(api, "provider") == {"k": "v"}
def test_icon(self, app):
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
return_value=(b"x", "image/png"),
),
):
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app):
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app):
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"id": "c1"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credential_info(self, app):
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
return_value={"info": "x"},
),
):
assert method(api, "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app):
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider") == {"schema": {}}
class TestApiProviderApis:
def test_add(self, app):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_remote_schema(self, app):
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/?url=http://x.com"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
return_value={"schema": "x"},
),
):
assert method(api)["schema"] == "x"
def test_list_tools(self, app):
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
return_value=[{"tool": 1}],
),
):
assert method(api) == [{"tool": 1}]
def test_update(self, app):
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"original_provider": "o",
"icon": {},
"privacy_policy": "",
"custom_disclaimer": "",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_delete(self, app):
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"provider": "p"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
return_value={"result": "success"},
),
):
assert method(api)["result"] == "success"
def test_get(self, app):
api = ToolApiProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
return_value={"x": 1},
),
):
assert method(api) == {"x": 1}
class TestWorkflowApis:
def test_create(self, app):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
payload = {
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "n",
"label": "l",
"description": "d",
"icon": {},
"parameters": [],
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_update_invalid(self, app):
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
payload = {
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "Tool",
"label": "Tool Label",
"description": "A tool",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
return_value={"ok": True},
),
):
result = method(api)
assert result["ok"]
def test_delete(self, app):
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_get_error(self, app):
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api)
class TestLists:
def test_builtin_list(self, app):
api = ToolBuiltinListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_api_list(self, app):
api = ToolApiListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_workflow_list(self, app):
api = ToolWorkflowListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
class TestLabels:
def test_labels(self, app):
api = ToolLabelsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
return_value=["l1"],
),
):
assert method(api) == ["l1"]
class TestOAuth:
def test_oauth_no_client(self, app):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "provider")
def test_oauth_callback_no_cookie(self, app):
api = ToolOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "provider")
class TestOAuthCustomClient:
def test_save_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"client_params": {"a": 1}}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
return_value={"client_id": "x"},
),
):
assert method(api, "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]

View File

@ -0,0 +1,558 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console.workspace.trigger_providers import (
TriggerOAuthAuthorizeApi,
TriggerOAuthCallbackApi,
TriggerOAuthClientManageApi,
TriggerProviderIconApi,
TriggerProviderInfoApi,
TriggerProviderListApi,
TriggerSubscriptionBuilderBuildApi,
TriggerSubscriptionBuilderCreateApi,
TriggerSubscriptionBuilderGetApi,
TriggerSubscriptionBuilderLogsApi,
TriggerSubscriptionBuilderUpdateApi,
TriggerSubscriptionBuilderVerifyApi,
TriggerSubscriptionDeleteApi,
TriggerSubscriptionListApi,
TriggerSubscriptionUpdateApi,
TriggerSubscriptionVerifyApi,
)
from controllers.web.error import NotFoundError
from core.plugin.entities.plugin_daemon import CredentialType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def mock_user():
user = MagicMock(spec=Account)
user.id = "u1"
user.current_tenant_id = "t1"
return user
class TestTriggerProviderApis:
def test_icon_success(self, app):
api = TriggerProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
return_value="icon",
),
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
api = TriggerProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
return_value=[],
),
):
assert method(api) == []
def test_provider_info(self, app):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
return_value={"id": "p1"},
),
):
assert method(api, "github") == {"id": "p1"}
class TestTriggerSubscriptionListApi:
def test_list_success(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
return_value=[],
),
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
side_effect=ValueError("bad"),
),
):
result, status = method(api, "bad")
assert status == 404
class TestTriggerSubscriptionBuilderApis:
def test_create_builder(self, app):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {"a": 1}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
return_value={"ok": True},
),
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
side_effect=Exception("err"),
),
):
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "n"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
log = MagicMock()
log.model_dump.return_value = {"a": 1}
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
return_value=[log],
),
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
return_value=None,
),
):
assert method(api, "github", "b1") == 200
class TestTriggerSubscriptionCrud:
def test_update_rename_only(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.UNAUTHORIZED
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.OAUTH2
sub.credentials = {}
sub.parameters = {}
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
),
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
mock_session = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
),
):
mock_db.engine = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
result = method(api, "sub1")
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
side_effect=ValueError("bad"),
),
):
mock_db.engine = MagicMock()
session_cls.return_value.__enter__.return_value = MagicMock()
with pytest.raises(BadRequest):
method(api, "sub1")
class TestTriggerOAuthApis:
def test_oauth_authorize_success(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value=MagicMock(id="b1"),
),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
return_value="ctx",
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
return_value=MagicMock(authorization_url="url"),
),
):
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
),
):
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials=None, expires_at=None),
),
):
with pytest.raises(ValueError):
method(api, "github")
class TestTriggerOAuthClientManageApi:
def test_get_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
return_value={},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
return_value=False,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
return_value=True,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
),
):
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
side_effect=ValueError("bad"),
),
):
with pytest.raises(BadRequest):
method(api, "github")
class TestTriggerSubscriptionVerifyApi:
def test_verify_success(self, app):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
return_value={"ok": True},
),
):
assert method(api, "github", "s1") == {"ok": True}
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
def test_verify_errors(self, app, raised_exception):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
side_effect=raised_exception,
),
):
with pytest.raises(BadRequest):
method(api, "github", "s1")

View File

@ -0,0 +1,605 @@
from datetime import datetime
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
import services
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.workspace.workspace import (
CustomConfigWorkspaceApi,
SwitchWorkspaceApi,
TenantApi,
TenantListApi,
WebappLogoWorkspaceApi,
WorkspaceInfoApi,
WorkspaceListApi,
WorkspacePermissionApi,
)
from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestTenantListApi:
def test_get_success(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
def test_get_billing_disabled(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="Tenant",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = False
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
),
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
class TestWorkspaceListApi:
def test_get_success(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
paginate_result = MagicMock(
items=[tenant],
has_next=False,
total=1,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
):
result, status = method(api)
assert status == 200
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="T",
status="active",
created_at=datetime.utcnow(),
)
paginate_result = MagicMock(
items=[tenant],
has_next=True,
total=10,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
patch(
"controllers.console.workspace.workspace.db.paginate",
return_value=paginate_result,
),
):
result, status = method(api)
assert status == 200
assert result["has_more"] is True
class TestTenantApi:
def test_post_active_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result, status = method(api)
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
api = TenantApi()
method = unwrap(api.post)
archived = MagicMock(status=TenantStatus.ARCHIVE)
new_tenant = MagicMock(status="active")
user = MagicMock(current_tenant=archived)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
),
):
result, status = method(api)
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
):
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/info"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(user, "t1"),
),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
):
result, status = method(api)
warn_mock.assert_called_once()
assert status == 200
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "t2"}
tenant = MagicMock(id="t2")
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
result = method(api)
assert result["result"] == "success"
def test_switch_not_linked(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "bad"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
):
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "missing"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
):
query_mock.return_value.get.return_value = None
with pytest.raises(ValueError):
method(api)
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={})
payload = {"remove_webapp_brand": True}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result = method(api)
assert result["result"] == "success"
def test_logo_fallback(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
payload = {"remove_webapp_brand": False}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.db.get_or_404",
return_value=tenant,
),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
):
result = method(api)
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
assert result["result"] == "success"
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
with (
app.test_request_context("/upload", data={}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
data = {
"file": MagicMock(),
"extra": MagicMock(),
}
with (
app.test_request_context("/upload", data=data),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = MagicMock(filename="test.txt")
with (
app.test_request_context("/upload", data={"file": file}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="logo.png",
content_type="image/png",
)
upload = MagicMock(id="file1")
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.return_value = upload
result, status = method(api)
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
with pytest.raises(UnsupportedFileTypeError):
method(api)
class TestWorkspaceInfoApi:
def test_post_success(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
tenant = MagicMock()
payload = {"name": "New Name"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"name": "New Name"},
),
):
result = method(api)
assert result["result"] == "success"
def test_no_current_tenant(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
payload = {"name": "X"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)
class TestWorkspacePermissionApi:
def test_get_success(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
permission = MagicMock(
workspace_id="t1",
allow_member_invite=True,
allow_owner_transfer=False,
)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
return_value=permission,
),
):
result, status = method(api)
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)

View File

@ -0,0 +1,142 @@
from __future__ import annotations
import importlib
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.workspace import plugin_permission_required
from models.account import TenantPluginPermission
class _SessionStub:
def __init__(self, permission):
self._permission = permission
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *_args, **_kwargs):
return self
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._permission
def _workspace_module():
return importlib.import_module(plugin_permission_required.__module__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
module = _workspace_module()
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, None)
@plugin_permission_required()
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()