Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox

# Conflicts:
#	api/core/app/apps/workflow/app_generator.py
This commit is contained in:
yyh
2026-01-19 16:32:43 +08:00
81 changed files with 8164 additions and 1193 deletions

View File

@ -1,254 +0,0 @@
"""
Unit tests for XSS prevention in App payloads.
This test module validates that HTML tags, JavaScript, and other potentially
dangerous content are rejected in App names and descriptions.
"""
import pytest
from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload
class TestXSSPreventionUnit:
"""Unit tests for XSS prevention in App payloads."""
def test_create_app_valid_names(self):
"""Test CreateAppPayload with valid app names."""
# Normal app names should be valid
valid_names = [
"My App",
"Test App 123",
"App with - dash",
"App with _ underscore",
"App with + plus",
"App with () parentheses",
"App with [] brackets",
"App with {} braces",
"App with ! exclamation",
"App with @ at",
"App with # hash",
"App with $ dollar",
"App with % percent",
"App with ^ caret",
"App with & ampersand",
"App with * asterisk",
"Unicode: 测试应用",
"Emoji: 🤖",
"Mixed: Test 测试 123",
]
for name in valid_names:
payload = CreateAppPayload(
name=name,
mode="chat",
)
assert payload.name == name
def test_create_app_xss_script_tags(self):
"""Test CreateAppPayload rejects script tags."""
xss_payloads = [
"<script>alert(document.cookie)</script>",
"<Script>alert(1)</Script>",
"<SCRIPT>alert('XSS')</SCRIPT>",
"<script>alert(String.fromCharCode(88,83,83))</script>",
"<script src='evil.js'></script>",
"<script>document.location='http://evil.com'</script>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_iframe_tags(self):
"""Test CreateAppPayload rejects iframe tags."""
xss_payloads = [
"<iframe src='evil.com'></iframe>",
"<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
"<IFRAME src='javascript:alert(1)'></iframe>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_javascript_protocol(self):
"""Test CreateAppPayload rejects javascript: protocol."""
xss_payloads = [
"javascript:alert(1)",
"JAVASCRIPT:alert(1)",
"JavaScript:alert(document.cookie)",
"javascript:void(0)",
"javascript://comment%0Aalert(1)",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_svg_onload(self):
"""Test CreateAppPayload rejects SVG with onload."""
xss_payloads = [
"<svg onload=alert(1)>",
"<SVG ONLOAD=alert(1)>",
"<svg/x/onload=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_event_handlers(self):
"""Test CreateAppPayload rejects HTML event handlers."""
xss_payloads = [
"<div onclick=alert(1)>",
"<img onerror=alert(1)>",
"<body onload=alert(1)>",
"<input onfocus=alert(1)>",
"<a onmouseover=alert(1)>",
"<DIV ONCLICK=alert(1)>",
"<img src=x onerror=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_object_embed(self):
"""Test CreateAppPayload rejects object and embed tags."""
xss_payloads = [
"<object data='evil.swf'></object>",
"<embed src='evil.swf'>",
"<OBJECT data='javascript:alert(1)'></OBJECT>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_link_javascript(self):
"""Test CreateAppPayload rejects link tags with javascript."""
xss_payloads = [
"<link href='javascript:alert(1)'>",
"<LINK HREF='javascript:alert(1)'>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_in_description(self):
"""Test CreateAppPayload rejects XSS in description."""
xss_descriptions = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for description in xss_descriptions:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(
name="Valid Name",
mode="chat",
description=description,
)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_valid_descriptions(self):
"""Test CreateAppPayload with valid descriptions."""
valid_descriptions = [
"A simple description",
"Description with < and > symbols",
"Description with & ampersand",
"Description with 'quotes' and \"double quotes\"",
"Description with / slashes",
"Description with \\ backslashes",
"Description with ; semicolons",
"Unicode: 这是一个描述",
"Emoji: 🎉🚀",
]
for description in valid_descriptions:
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=description,
)
assert payload.description == description
def test_create_app_none_description(self):
"""Test CreateAppPayload with None description."""
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=None,
)
assert payload.description is None
def test_update_app_xss_prevention(self):
"""Test UpdateAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
UpdateAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_update_app_valid_names(self):
"""Test UpdateAppPayload with valid names."""
payload = UpdateAppPayload(name="Valid Updated Name")
assert payload.name == "Valid Updated Name"
def test_copy_app_xss_prevention(self):
"""Test CopyAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
CopyAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_copy_app_valid_names(self):
"""Test CopyAppPayload with valid names."""
payload = CopyAppPayload(name="Valid Copy Name")
assert payload.name == "Valid Copy Name"
def test_copy_app_none_name(self):
"""Test CopyAppPayload with None name (should be allowed)."""
payload = CopyAppPayload(name=None)
assert payload.name is None
def test_edge_case_angle_brackets_content(self):
"""Test that angle brackets with actual content are rejected."""
# Angle brackets without valid HTML-like patterns should be checked
# The regex pattern <.*?on\w+\s*= should catch event handlers
# But let's verify other patterns too
# Valid: angle brackets used as symbols (not matched by our patterns)
# Our patterns specifically look for dangerous constructs
# Invalid: actual HTML tags with event handlers
invalid_names = [
"<div onclick=xss>",
"<img src=x onerror=alert(1)>",
]
for name in invalid_names:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()

View File

@ -346,6 +346,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeRateLimitError",
"message": "Rate limit exceeded",
"args": {"description": "Rate limit exceeded"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@ -364,6 +365,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeAuthorizationError",
"message": "Invalid credentials",
"args": {"description": "Invalid credentials"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@ -382,6 +384,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeBadRequestError",
"message": "Invalid parameters",
"args": {"description": "Invalid parameters"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@ -400,6 +403,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeConnectionError",
"message": "Connection to external service failed",
"args": {"description": "Connection to external service failed"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@ -418,6 +422,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeServerUnavailableError",
"message": "Service temporarily unavailable",
"args": {"description": "Service temporarily unavailable"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})

View File

@ -0,0 +1 @@
"""Tests for workflow context management."""

View File

@ -0,0 +1,258 @@
"""Tests for execution context module."""
import contextvars
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.workflow.context.execution_context import (
AppContext,
ExecutionContext,
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
)
class TestAppContext:
"""Test AppContext abstract base class."""
def test_app_context_is_abstract(self):
"""Test that AppContext cannot be instantiated directly."""
with pytest.raises(TypeError):
AppContext() # type: ignore
class TestNullAppContext:
"""Test NullAppContext implementation."""
def test_null_app_context_get_config(self):
"""Test get_config returns value from config dict."""
config = {"key1": "value1", "key2": "value2"}
ctx = NullAppContext(config=config)
assert ctx.get_config("key1") == "value1"
assert ctx.get_config("key2") == "value2"
def test_null_app_context_get_config_default(self):
"""Test get_config returns default when key not found."""
ctx = NullAppContext()
assert ctx.get_config("nonexistent", "default") == "default"
assert ctx.get_config("nonexistent") is None
def test_null_app_context_get_extension(self):
"""Test get_extension returns stored extension."""
ctx = NullAppContext()
extension = MagicMock()
ctx.set_extension("db", extension)
assert ctx.get_extension("db") == extension
def test_null_app_context_get_extension_not_found(self):
"""Test get_extension returns None when extension not found."""
ctx = NullAppContext()
assert ctx.get_extension("nonexistent") is None
def test_null_app_context_enter_yield(self):
"""Test enter method yields without any side effects."""
ctx = NullAppContext()
with ctx.enter():
# Should not raise any exception
pass
class TestExecutionContext:
"""Test ExecutionContext class."""
def test_initialization_with_all_params(self):
"""Test ExecutionContext initialization with all parameters."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = ExecutionContext(
app_context=app_ctx,
context_vars=context_vars,
user=user,
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_initialization_with_minimal_params(self):
"""Test ExecutionContext initialization with minimal parameters."""
ctx = ExecutionContext()
assert ctx.app_context is None
assert ctx.context_vars is None
assert ctx.user is None
def test_enter_with_context_vars(self):
"""Test enter restores context variables."""
test_var = contextvars.ContextVar("test_var")
test_var.set("original_value")
# Copy context with the variable
context_vars = contextvars.copy_context()
# Change the variable
test_var.set("new_value")
# Create execution context and enter it
ctx = ExecutionContext(context_vars=context_vars)
with ctx.enter():
# Variable should be restored to original value
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_with_app_context(self):
"""Test enter enters app context if available."""
app_ctx = NullAppContext()
ctx = ExecutionContext(app_context=app_ctx)
# Should not raise any exception
with ctx.enter():
pass
def test_enter_without_app_context(self):
"""Test enter works without app context."""
ctx = ExecutionContext(app_context=None)
# Should not raise any exception
with ctx.enter():
pass
def test_context_manager_protocol(self):
"""Test ExecutionContext supports context manager protocol."""
ctx = ExecutionContext()
with ctx:
# Should not raise any exception
pass
def test_user_property(self):
"""Test user property returns set user."""
user = MagicMock()
ctx = ExecutionContext(user=user)
assert ctx.user == user
class TestIExecutionContextProtocol:
"""Test IExecutionContext protocol."""
def test_execution_context_implements_protocol(self):
"""Test that ExecutionContext implements IExecutionContext protocol."""
ctx = ExecutionContext()
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
def test_protocol_compatibility(self):
"""Test that ExecutionContext can be used where IExecutionContext is expected."""
def accept_context(context: IExecutionContext) -> Any:
"""Function that accepts IExecutionContext protocol."""
# Just verify it has the required protocol attributes
assert hasattr(context, "__enter__")
assert hasattr(context, "__exit__")
assert hasattr(context, "user")
return context.user
ctx = ExecutionContext(user="test_user")
result = accept_context(ctx)
assert result == "test_user"
def test_protocol_with_flask_execution_context(self):
"""Test that IExecutionContext protocol is compatible with different implementations."""
# Verify the protocol works with ExecutionContext
ctx = ExecutionContext(user="test_user")
# Should have the required protocol attributes
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
assert ctx.user == "test_user"
# Should work as context manager
with ctx:
assert ctx.user == "test_user"
class TestExecutionContextBuilder:
"""Test ExecutionContextBuilder class."""
def test_builder_with_all_params(self):
"""Test builder with all parameters set."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = (
ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build()
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_builder_with_partial_params(self):
"""Test builder with only some parameters set."""
app_ctx = NullAppContext()
ctx = ExecutionContextBuilder().with_app_context(app_ctx).build()
assert ctx.app_context == app_ctx
assert ctx.context_vars is None
assert ctx.user is None
def test_builder_fluent_interface(self):
"""Test builder provides fluent interface."""
builder = ExecutionContextBuilder()
# Each method should return the builder
assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder)
assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder)
assert isinstance(builder.with_user(None), ExecutionContextBuilder)
class TestCaptureCurrentContext:
"""Test capture_current_context function."""
def test_capture_current_context_returns_context(self):
"""Test that capture_current_context returns a valid context."""
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Should return an object that implements IExecutionContext
assert hasattr(result, "__enter__")
assert hasattr(result, "__exit__")
assert hasattr(result, "user")
def test_capture_current_context_captures_contextvars(self):
"""Test that capture_current_context captures context variables."""
# Set a context variable before capturing
import contextvars
test_var = contextvars.ContextVar("capture_test_var")
test_var.set("test_value_123")
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Context variables should be captured
assert result.context_vars is not None

View File

@ -0,0 +1,316 @@
"""Tests for Flask app context module."""
import contextvars
from unittest.mock import MagicMock, patch
import pytest
class TestFlaskAppContext:
"""Test FlaskAppContext implementation."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {"TEST_KEY": "test_value"}
app.extensions = {"db": MagicMock(), "cache": MagicMock()}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_flask_app_context_initialization(self, mock_flask_app):
"""Test FlaskAppContext initialization."""
# Import here to avoid Flask dependency in test environment
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.flask_app == mock_flask_app
def test_flask_app_context_get_config(self, mock_flask_app):
"""Test get_config returns Flask app config value."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("TEST_KEY") == "test_value"
def test_flask_app_context_get_config_default(self, mock_flask_app):
"""Test get_config returns default when key not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("NONEXISTENT", "default") == "default"
def test_flask_app_context_get_extension(self, mock_flask_app):
"""Test get_extension returns Flask extension."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
db_ext = mock_flask_app.extensions["db"]
assert ctx.get_extension("db") == db_ext
def test_flask_app_context_get_extension_not_found(self, mock_flask_app):
"""Test get_extension returns None when extension not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_extension("nonexistent") is None
def test_flask_app_context_enter(self, mock_flask_app):
"""Test enter method enters Flask app context."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
with ctx.enter():
# Should not raise any exception
pass
# Verify app_context was called
mock_flask_app.app_context.assert_called_once()
class TestFlaskExecutionContext:
"""Test FlaskExecutionContext class."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_initialization(self, mock_flask_app):
"""Test FlaskExecutionContext initialization."""
from context.flask_app_context import FlaskExecutionContext
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
user=user,
)
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_app_context_property(self, mock_flask_app):
"""Test app_context property returns FlaskAppContext."""
from context.flask_app_context import FlaskAppContext, FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
assert isinstance(ctx.app_context, FlaskAppContext)
assert ctx.app_context.flask_app == mock_flask_app
def test_context_manager_protocol(self, mock_flask_app):
"""Test FlaskExecutionContext supports context manager protocol."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
# Should work as context manager
with ctx:
pass
class TestCaptureFlaskContext:
"""Test capture_flask_context function."""
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
"""Test capture_flask_context captures Flask app."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx._flask_app == mock_app
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
"""Test capture_flask_context captures user from Flask g object."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
mock_user = MagicMock()
mock_user.id = "user_123"
mock_g._login_user = mock_user
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx.user == mock_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_with_explicit_user(self, mock_current_app):
"""Test capture_flask_context uses explicit user parameter."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
explicit_user = MagicMock()
explicit_user.id = "user_456"
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context(user=explicit_user)
assert ctx.user == explicit_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_captures_contextvars(self, mock_current_app):
"""Test capture_flask_context captures context variables."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
# Set a context variable
test_var = contextvars.ContextVar("test_var")
test_var.set("test_value")
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
# Context variables should be captured
assert ctx.context_vars is not None
# Verify the variable is in the captured context
captured_value = ctx.context_vars[test_var]
assert captured_value == "test_value"
class TestFlaskExecutionContextIntegration:
"""Integration tests for FlaskExecutionContext."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app with proper app context."""
app = MagicMock()
app.config = {"TEST": "value"}
app.extensions = {"db": MagicMock()}
# Mock app context
mock_app_context = MagicMock()
mock_app_context.__enter__ = MagicMock(return_value=None)
mock_app_context.__exit__ = MagicMock(return_value=None)
app.app_context.return_value = mock_app_context
return app
def test_enter_restores_context_vars(self, mock_flask_app):
"""Test that enter restores captured context variables."""
# Create a context variable and set a value
test_var = contextvars.ContextVar("integration_test_var")
test_var.set("original_value")
# Capture the context
context_vars = contextvars.copy_context()
# Change the value
test_var.set("new_value")
# Create FlaskExecutionContext and enter it
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
)
with ctx:
# Value should be restored to original
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_enters_flask_app_context(self, mock_flask_app):
"""Test that enter enters Flask app context."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
with ctx:
# Verify app context was entered
assert mock_flask_app.app_context.called
@patch("context.flask_app_context.g")
def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
"""Test that enter restores user in Flask g object."""
mock_user = MagicMock()
mock_user.id = "test_user"
# Note: FlaskExecutionContext saves user from g before entering context,
# then restores it after entering the app context.
# The user passed to constructor is NOT restored to g.
# So we need to test the actual behavior.
# Create FlaskExecutionContext with user in constructor
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
user=mock_user,
)
# Set user in g before entering (simulating existing user in g)
mock_g._login_user = mock_user
with ctx:
# After entering, the user from g before entry should be restored
assert mock_g._login_user == mock_user
# The user in constructor is stored but not automatically restored to g
# (it's available via ctx.user property)
assert ctx.user == mock_user
def test_enter_method_as_context_manager(self, mock_flask_app):
"""Test enter method returns a proper context manager."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# enter() should return a generator/context manager
with ctx.enter():
# Should work without issues
pass
# Verify app context was called
assert mock_flask_app.app_context.called

View File

@ -0,0 +1,142 @@
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.workspace_permission import (
check_workspace_member_invite_permission,
check_workspace_owner_transfer_permission,
)
class TestWorkspacePermissionHelper:
"""Test workspace permission helper functions."""
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.EnterpriseService")
def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
"""Community edition should always allow invitations without calling any service."""
mock_config.ENTERPRISE_ENABLED = False
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
# EnterpriseService should NOT be called in community edition
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
"""Community edition should check billing plan but not call enterprise service."""
mock_config.ENTERPRISE_ENABLED = False
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True
mock_feature_service.get_features.return_value = mock_features
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_feature_service.get_features.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should block invitations when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = False
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should allow invitations when workspace policy is True."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = True
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
"""SANDBOX billing plan should block owner transfer before checking enterprise policy."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = False # SANDBOX plan
mock_feature_service.get_features.return_value = mock_features
with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
# Enterprise service should NOT be called since billing plan already blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
"""Enterprise edition should block transfer when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = False # Workspace policy blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_allows_transfer_when_both_enabled(
self, mock_feature_service, mock_config, mock_enterprise_service
):
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = True # Workspace policy allows
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.logger")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
"""On enterprise service error, should fail-open (allow) and log error."""
mock_config.ENTERPRISE_ENABLED = True
# Simulate enterprise service error
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
# Should not raise (fail-open)
check_workspace_member_invite_permission("test-workspace-id")
# Should log the error
mock_logger.exception.assert_called_once()
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)