From 5e336c47fda65399c2686796470772f7804de354 Mon Sep 17 00:00:00 2001 From: Junyan Chin Date: Fri, 24 Apr 2026 15:53:14 +0800 Subject: [PATCH] feat: marketplace and oauth fixes (#35509) Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/.env.example | 5 + api/commands/plugin.py | 6 +- api/configs/feature/__init__.py | 22 + api/controllers/console/app/app.py | 26 + api/core/helper/creators.py | 41 ++ ...uth_encryption.py => system_encryption.py} | 82 +-- api/services/feature_service.py | 4 + .../tools/builtin_tools_manage_service.py | 4 +- .../trigger/trigger_provider_service.py | 4 +- .../unit_tests/core/helper/test_creators.py | 106 +++ .../utils/test_system_oauth_encryption.py | 48 +- .../services/test_trigger_provider_service.py | 4 +- .../test_builtin_tools_manage_service.py | 2 +- .../encryption/test_system_encryption.py | 619 ++++++++++++++++++ .../test_system_oauth_encryption.py | 619 ------------------ docker/.env.example | 5 + docker/docker-compose.yaml | 3 + web/app/account/oauth/authorize/page.tsx | 15 +- web/app/components/app-initializer.tsx | 2 +- .../app-publisher/__tests__/index.spec.tsx | 73 +++ .../components/app/app-publisher/index.tsx | 35 +- .../__tests__/index.spec.tsx | 4 +- .../app/create-from-dsl-modal/index.tsx | 2 +- .../components/apps/__tests__/index.spec.tsx | 96 ++- ...import-from-marketplace-template-modal.tsx | 182 +++++ web/app/components/apps/index.tsx | 38 ++ .../__tests__/panel-contextmenu.spec.tsx | 2 +- .../components/workflow/panel-contextmenu.tsx | 2 +- .../components/workflow/update-dsl-modal.tsx | 2 +- web/app/page.tsx | 31 +- web/app/signin/check-code/page.tsx | 2 +- .../components/mail-and-password-auth.tsx | 2 +- web/app/signin/invite-settings/page.tsx | 2 +- web/app/signin/normal-form.tsx | 2 +- web/app/signin/utils/post-login-redirect.ts | 68 +- web/contract/marketplace.ts | 13 + web/contract/router.ts | 3 +- web/i18n/en-US/app.json | 17 + web/i18n/en-US/workflow.json | 3 + web/i18n/zh-Hans/app.json | 17 + web/i18n/zh-Hans/workflow.json | 3 + web/next/navigation.ts | 2 + web/service/__tests__/base.spec.ts | 68 ++ web/service/apps.ts | 8 + web/service/base.ts | 18 +- web/service/marketplace-templates.ts | 18 + web/types/feature.ts | 2 + web/types/marketplace-template.ts | 11 + 48 files changed, 1604 insertions(+), 739 deletions(-) create mode 100644 api/core/helper/creators.py rename api/core/tools/utils/{system_oauth_encryption.py => system_encryption.py} (57%) create mode 100644 api/tests/unit_tests/core/helper/test_creators.py create mode 100644 api/tests/unit_tests/utils/encryption/test_system_encryption.py delete mode 100644 api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py create mode 100644 web/app/components/apps/import-from-marketplace-template-modal.tsx create mode 100644 web/service/__tests__/base.spec.ts create mode 100644 web/service/marketplace-templates.ts create mode 100644 web/types/marketplace-template.ts diff --git a/api/.env.example b/api/.env.example index 6cfe0266c2..f6f65011ea 100644 --- a/api/.env.example +++ b/api/.env.example @@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + # Endpoint configuration ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} diff --git a/api/commands/plugin.py b/api/commands/plugin.py index c34391025a..8bd5392d7b 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,7 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant from models.oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) @@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ae49ae47d0..52e33c1789 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings): ) +class CreatorsPlatformConfig(BaseSettings): + """ + Configuration for Creators Platform integration + """ + + CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field( + description="Enable or disable Creators Platform features", + default=True, + ) + + CREATORS_PLATFORM_API_URL: HttpUrl = Field( + description="Creators Platform API URL", + default=HttpUrl("https://creators.dify.ai"), + ) + + CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field( + description="OAuth client ID for Creators Platform integration", + default="", + ) + + class EndpointConfig(BaseSettings): """ Configuration for various application endpoints and URLs @@ -1379,6 +1400,7 @@ class FeatureConfig( AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, + CreatorsPlatformConfig, TriggerConfig, AsyncWorkflowConfig, PluginConfig, diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9102983d86..a736fc8bc8 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -692,6 +692,32 @@ class AppExportApi(Resource): return payload.model_dump(mode="json") +@console_ns.route("/apps//publish-to-creators-platform") +class AppPublishToCreatorsPlatformApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=None) + @edit_permission_required + def post(self, app_model): + """Publish app to Creators Platform""" + from configs import dify_config + from core.helper.creators import get_redirect_url, upload_dsl + + if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + return {"error": "Creators Platform features are not enabled"}, 403 + + current_user, _ = current_account_with_tenant() + + dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False) + dsl_bytes = dsl_content.encode("utf-8") + + claim_code = upload_dsl(dsl_bytes) + redirect_url = get_redirect_url(str(current_user.id), claim_code) + + return {"redirect_url": redirect_url} + + @console_ns.route("/apps//name") class AppNameApi(Resource): @console_ns.doc("check_app_name") diff --git a/api/core/helper/creators.py b/api/core/helper/creators.py new file mode 100644 index 0000000000..b01e16f18a --- /dev/null +++ b/api/core/helper/creators.py @@ -0,0 +1,41 @@ +""" +Helper module for Creators Platform integration. + +Provides functionality to upload DSL files to the Creators Platform +and generate redirect URLs with OAuth authorization codes. +""" + +import logging +from urllib.parse import urlencode + +import httpx +from yarl import URL + +from configs import dify_config + +logger = logging.getLogger(__name__) + +creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL)) + + +def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str: + url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload") + response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30) + response.raise_for_status() + data = response.json() + claim_code = data.get("data", {}).get("claim_code") + if not claim_code: + raise ValueError("Creators Platform did not return a valid claim_code") + return claim_code + + +def get_redirect_url(user_account_id: str, claim_code: str) -> str: + base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/") + params: dict[str, str] = {"dsl_claim_code": claim_code} + client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "") + if client_id: + from services.oauth_server import OAuthServerService + + oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id) + params["oauth_code"] = oauth_code + return f"{base_url}?{urlencode(params)}" diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_encryption.py similarity index 57% rename from api/core/tools/utils/system_oauth_encryption.py rename to api/core/tools/utils/system_encryption.py index 6b7007842d..ca7e6a13fe 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_encryption.py @@ -14,23 +14,23 @@ from configs import dify_config logger = logging.getLogger(__name__) -class OAuthEncryptionError(Exception): - """OAuth encryption/decryption specific error""" +class EncryptionError(Exception): + """Encryption/decryption specific error""" pass -class SystemOAuthEncrypter: +class SystemEncrypter: """ - A simple OAuth parameters encrypter using AES-CBC encryption. + A simple parameters encrypter using AES-CBC encryption. - This class provides methods to encrypt and decrypt OAuth parameters + This class provides methods to encrypt and decrypt parameters using AES-CBC mode with a key derived from the application's SECRET_KEY. """ def __init__(self, secret_key: str | None = None): """ - Initialize the OAuth encrypter. + Initialize the encrypter. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY @@ -43,19 +43,19 @@ class SystemOAuthEncrypter: # Generate a fixed 256-bit key using SHA-256 self.key = hashlib.sha256(secret_key.encode()).digest() - def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + def encrypt_params(self, params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters. + Encrypt parameters. Args: - oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} Returns: Base64-encoded encrypted string Raises: - OAuthEncryptionError: If encryption fails - ValueError: If oauth_params is invalid + EncryptionError: If encryption fails + ValueError: If params is invalid """ try: @@ -66,7 +66,7 @@ class SystemOAuthEncrypter: cipher = AES.new(self.key, AES.MODE_CBC, iv) # Encrypt data - padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size) encrypted_data = cipher.encrypt(padded_data) # Combine IV and encrypted data @@ -76,20 +76,20 @@ class SystemOAuthEncrypter: return base64.b64encode(combined).decode() except Exception as e: - raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + raise EncryptionError(f"Encryption failed: {str(e)}") from e - def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters. + Decrypt parameters. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary Raises: - OAuthEncryptionError: If decryption fails + EncryptionError: If decryption fails ValueError: If encrypted_data is invalid """ if not isinstance(encrypted_data, str): @@ -118,70 +118,70 @@ class SystemOAuthEncrypter: unpadded_data = unpad(decrypted_data, AES.block_size) # Parse JSON - oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) - if not isinstance(oauth_params, dict): + if not isinstance(params, dict): raise ValueError("Decrypted data is not a valid dictionary") - return oauth_params + return params except Exception as e: - raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + raise EncryptionError(f"Decryption failed: {str(e)}") from e # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: +def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter: """ - Create an OAuth encrypter instance. + Create an encrypter instance. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - return SystemOAuthEncrypter(secret_key=secret_key) + return SystemEncrypter(secret_key=secret_key) # Global encrypter instance (for backward compatibility) -_oauth_encrypter: SystemOAuthEncrypter | None = None +_encrypter: SystemEncrypter | None = None -def get_system_oauth_encrypter() -> SystemOAuthEncrypter: +def get_system_encrypter() -> SystemEncrypter: """ - Get the global OAuth encrypter instance. + Get the global encrypter instance. Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - global _oauth_encrypter - if _oauth_encrypter is None: - _oauth_encrypter = SystemOAuthEncrypter() - return _oauth_encrypter + global _encrypter + if _encrypter is None: + _encrypter = SystemEncrypter() + return _encrypter # Convenience functions for backward compatibility -def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: +def encrypt_system_params(params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters using the global encrypter. + Encrypt parameters using the global encrypter. Args: - oauth_params: OAuth parameters dictionary + params: Parameters dictionary Returns: Base64-encoded encrypted string """ - return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + return get_system_encrypter().encrypt_params(params) -def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: +def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters using the global encrypter. + Decrypt parameters using the global encrypter. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary """ - return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) + return get_system_encrypter().decrypt_params(encrypted_data) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index e18eb096c9..38518378f7 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel): enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() trial_models: list[str] = [] + enable_creators_platform: bool = False enable_trial_app: bool = False enable_explore_banner: bool = False @@ -241,6 +242,9 @@ class FeatureService: if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True + if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + system_features.enable_creators_platform = True + return system_features @classmethod diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 7bd056b8a0..b8242ab3a5 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID @@ -521,7 +521,7 @@ class BuiltinToolManageService: ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 6e14d996ea..b8a76e4945 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, @@ -635,7 +635,7 @@ class TriggerProviderService: if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/tests/unit_tests/core/helper/test_creators.py b/api/tests/unit_tests/core/helper/test_creators.py new file mode 100644 index 0000000000..df67d3f513 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_creators.py @@ -0,0 +1,106 @@ +"""Tests for the Creators Platform helper module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from yarl import URL + + +@pytest.fixture(autouse=True) +def _patch_creators_url(monkeypatch): + """Patch the module-level creators_platform_api_url for all tests.""" + monkeypatch.setattr( + "core.helper.creators.creators_platform_api_url", + URL("https://creators.example.com"), + ) + + +class TestUploadDSL: + @patch("core.helper.creators.httpx.post") + def test_returns_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {"claim_code": "abc123"}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + result = upload_dsl(b"app: demo", "demo.yaml") + + assert result == "abc123" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert "anonymous-upload" in call_kwargs.args[0] + assert call_kwargs.kwargs["timeout"] == 30 + + @patch("core.helper.creators.httpx.post") + def test_raises_on_missing_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(ValueError, match="claim_code"): + upload_dsl(b"app: demo") + + @patch("core.helper.creators.httpx.post") + def test_raises_on_http_error(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server Error", + request=MagicMock(), + response=MagicMock(), + ) + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(httpx.HTTPStatusError): + upload_dsl(b"app: demo") + + +class TestGetRedirectUrl: + @patch("core.helper.creators.dify_config") + def test_without_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code" not in url + assert url.startswith("https://creators.example.com") + + @patch("core.helper.creators.dify_config") + def test_with_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz" + + with patch( + "services.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="oauth-code-123", + ) as mock_sign: + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + mock_sign.assert_called_once_with("client-xyz", "user-1") + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code=oauth-code-123" in url + + @patch("core.helper.creators.dify_config") + def test_strips_trailing_slash(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert url.startswith("https://creators.example.com?") + assert "creators.example.com/?" not in url diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py index 5691f33e65..6bb86ebe78 100644 --- a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -2,50 +2,50 @@ from __future__ import annotations import pytest -from core.tools.utils import system_oauth_encryption as oauth_encryption -from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter +from core.tools.utils import system_encryption as encryption +from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter -def test_system_oauth_encrypter_roundtrip(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_roundtrip(): + encrypter = SystemEncrypter(secret_key="test-secret") payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} - encrypted = encrypter.encrypt_oauth_params(payload) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(payload) + decrypted = encrypter.decrypt_params(encrypted) assert encrypted assert dict(decrypted) == payload -def test_system_oauth_encrypter_decrypt_validates_input(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_decrypt_validates_input(): + encrypter = SystemEncrypter(secret_key="test-secret") with pytest.raises(ValueError, match="must be a string"): - encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + encrypter.decrypt_params(123) # type: ignore[arg-type] with pytest.raises(ValueError, match="cannot be empty"): - encrypter.decrypt_oauth_params("") + encrypter.decrypt_params("") -def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_raises_error_for_invalid_ciphertext(): + encrypter = SystemEncrypter(secret_key="test-secret") - with pytest.raises(OAuthEncryptionError, match="Decryption failed"): - encrypter.decrypt_oauth_params("not-base64") + with pytest.raises(EncryptionError, match="Decryption failed"): + encrypter.decrypt_params("not-base64") -def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): - monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) - monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") +def test_system_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(encryption, "_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret") - first = oauth_encryption.get_system_oauth_encrypter() - second = oauth_encryption.get_system_oauth_encrypter() + first = encryption.get_system_encrypter() + second = encryption.get_system_encrypter() assert first is second - encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) - assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + encrypted = encryption.encrypt_system_params({"k": "v"}) + assert encryption.decrypt_system_params(encrypted) == {"k": "v"} -def test_create_system_oauth_encrypter_factory(): - encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") - assert isinstance(encrypter, SystemOAuthEncrypter) +def test_create_system_encrypter_factory(): + encrypter = encryption.create_system_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemEncrypter) diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index ebf1b36610..6eba60e5f1 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -694,7 +694,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified( _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( - "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + "services.trigger.trigger_provider_service.decrypt_system_params", return_value={"client_id": "system"}, ) @@ -716,7 +716,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails( _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( - "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + "services.trigger.trigger_provider_service.decrypt_system_params", side_effect=RuntimeError("bad data"), ) diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 79a2d30f57..ce0d94398d 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -280,7 +280,7 @@ class TestGetOauthClient: assert result == {"client_id": "id", "client_secret": "secret"} - @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"}) + @patch(f"{MODULE}.decrypt_system_params", return_value={"sys_key": "sys_val"}) @patch(f"{MODULE}.PluginService") @patch(f"{MODULE}.create_provider_encrypter") @patch(f"{MODULE}.ToolManager") diff --git a/api/tests/unit_tests/utils/encryption/test_system_encryption.py b/api/tests/unit_tests/utils/encryption/test_system_encryption.py new file mode 100644 index 0000000000..0435facfdb --- /dev/null +++ b/api/tests/unit_tests/utils/encryption/test_system_encryption.py @@ -0,0 +1,619 @@ +import base64 +import hashlib +from unittest.mock import patch + +import pytest +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad + +from core.tools.utils.system_encryption import ( + EncryptionError, + SystemEncrypter, + create_system_encrypter, + decrypt_system_params, + encrypt_system_params, + get_system_encrypter, +) + + +class TestSystemEncrypter: + """Test cases for SystemEncrypter class""" + + def test_init_with_secret_key(self): + """Test initialization with provided secret key""" + secret_key = "test_secret_key" + encrypter = SystemEncrypter(secret_key=secret_key) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_init_with_none_secret_key(self): + """Test initialization with None secret key falls back to config""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = SystemEncrypter(secret_key=None) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_init_with_empty_secret_key(self): + """Test initialization with empty secret key""" + encrypter = SystemEncrypter(secret_key="") + expected_key = hashlib.sha256(b"").digest() + assert encrypter.key == expected_key + + def test_init_without_secret_key_uses_config(self): + """Test initialization without secret key uses config""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "default_secret" + encrypter = SystemEncrypter() + expected_key = hashlib.sha256(b"default_secret").digest() + assert encrypter.key == expected_key + + def test_encrypt_params_basic(self): + """Test basic parameters encryption""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_params(params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + # Should be valid base64 + try: + base64.b64decode(encrypted) + except Exception: + pytest.fail("Encrypted result is not valid base64") + + def test_encrypt_params_empty_dict(self): + """Test encryption with empty dictionary""" + encrypter = SystemEncrypter("test_secret") + params = {} + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_complex_data(self): + """Test encryption with complex data structures""" + encrypter = SystemEncrypter("test_secret") + params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_unicode_data(self): + """Test encryption with unicode data""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"} + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_large_data(self): + """Test encryption with large data""" + encrypter = SystemEncrypter("test_secret") + params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_params(params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_params_invalid_input(self): + """Test encryption with invalid input types""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_params(None) + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_params("not_a_dict") + + def test_decrypt_params_basic(self): + """Test basic parameters decryption""" + encrypter = SystemEncrypter("test_secret") + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_empty_dict(self): + """Test decryption of empty dictionary""" + encrypter = SystemEncrypter("test_secret") + original_params = {} + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_complex_data(self): + """Test decryption with complex data structures""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_unicode_data(self): + """Test decryption with unicode data""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "description": "This is a test case 🚀", + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_large_data(self): + """Test decryption with large data""" + encrypter = SystemEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_params_invalid_base64(self): + """Test decryption with invalid base64 data""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(EncryptionError): + encrypter.decrypt_params("invalid_base64!") + + def test_decrypt_params_empty_string(self): + """Test decryption with empty string""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params("") + + assert "encrypted_data cannot be empty" in str(exc_info.value) + + def test_decrypt_params_non_string_input(self): + """Test decryption with non-string input""" + encrypter = SystemEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(123) + + assert "encrypted_data must be a string" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(None) + + assert "encrypted_data must be a string" in str(exc_info.value) + + def test_decrypt_params_too_short_data(self): + """Test decryption with too short encrypted data""" + encrypter = SystemEncrypter("test_secret") + + # Create data that's too short (less than 32 bytes) + short_data = base64.b64encode(b"short").decode() + + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) + + assert "Invalid encrypted data format" in str(exc_info.value) + + def test_decrypt_params_corrupted_data(self): + """Test decryption with corrupted data""" + encrypter = SystemEncrypter("test_secret") + + # Create corrupted data (valid base64 but invalid encrypted content) + corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage + + with pytest.raises(EncryptionError): + encrypter.decrypt_params(corrupted_data) + + def test_decrypt_params_wrong_key(self): + """Test decryption with wrong key""" + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") + + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + encrypted = encrypter1.encrypt_params(original_params) + + with pytest.raises(EncryptionError): + encrypter2.decrypt_params(encrypted) + + def test_encryption_decryption_consistency(self): + """Test that encryption and decryption are consistent""" + encrypter = SystemEncrypter("test_secret") + + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + {"array": [1, 2, 3, "four", {"five": 5}]}, + ] + + for original_params in test_cases: + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_encryption_randomness(self): + """Test that encryption produces different results for same input""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter.encrypt_params(params) + encrypted2 = encrypter.encrypt_params(params) + + # Should be different due to random IV + assert encrypted1 != encrypted2 + + # But should decrypt to same result + decrypted1 = encrypter.decrypt_params(encrypted1) + decrypted2 = encrypter.decrypt_params(encrypted2) + assert decrypted1 == decrypted2 == params + + def test_different_secret_keys_produce_different_results(self): + """Test that different secret keys produce different encrypted results""" + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") + + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter1.encrypt_params(params) + encrypted2 = encrypter2.encrypt_params(params) + + # Should produce different encrypted results + assert encrypted1 != encrypted2 + + # But each should decrypt correctly with its own key + decrypted1 = encrypter1.decrypt_params(encrypted1) + decrypted2 = encrypter2.decrypt_params(encrypted2) + assert decrypted1 == decrypted2 == params + + @patch("core.tools.utils.system_encryption.get_random_bytes") + def test_encrypt_params_crypto_error(self, mock_get_random_bytes): + """Test encryption when crypto operation fails""" + mock_get_random_bytes.side_effect = Exception("Crypto error") + + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id"} + + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(params) + + assert "Encryption failed" in str(exc_info.value) + + @patch("core.tools.utils.system_encryption.TypeAdapter") + def test_encrypt_params_serialization_error(self, mock_type_adapter): + """Test encryption when JSON serialization fails""" + mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") + + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id"} + + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(params) + + assert "Encryption failed" in str(exc_info.value) + + def test_decrypt_params_invalid_json(self): + """Test decryption with invalid JSON data""" + encrypter = SystemEncrypter("test_secret") + + # Create valid encrypted data but with invalid JSON content + iv = get_random_bytes(16) + cipher = AES.new(encrypter.key, AES.MODE_CBC, iv) + invalid_json = b"invalid json content" + padded_data = pad(invalid_json, AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + combined = iv + encrypted_data + encoded = base64.b64encode(combined).decode() + + with pytest.raises(EncryptionError): + encrypter.decrypt_params(encoded) + + def test_key_derivation_consistency(self): + """Test that key derivation is consistent""" + secret_key = "test_secret" + encrypter1 = SystemEncrypter(secret_key) + encrypter2 = SystemEncrypter(secret_key) + + assert encrypter1.key == encrypter2.key + + # Keys should be 32 bytes (256 bits) + assert len(encrypter1.key) == 32 + + +class TestFactoryFunctions: + """Test cases for factory functions""" + + def test_create_system_encrypter_with_secret(self): + """Test factory function with secret key""" + secret_key = "test_secret" + encrypter = create_system_encrypter(secret_key) + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_create_system_encrypter_without_secret(self): + """Test factory function without secret key""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_encrypter() + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_create_system_encrypter_with_none_secret(self): + """Test factory function with None secret key""" + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_encrypter(None) + + assert isinstance(encrypter, SystemEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + +class TestGlobalEncrypterInstance: + """Test cases for global encrypter instance""" + + def test_get_system_encrypter_singleton(self): + """Test that get_system_encrypter returns singleton instance""" + # Clear the global instance first + import core.tools.utils.system_encryption + + core.tools.utils.system_encryption._encrypter = None + + encrypter1 = get_system_encrypter() + encrypter2 = get_system_encrypter() + + assert encrypter1 is encrypter2 + assert isinstance(encrypter1, SystemEncrypter) + + def test_get_system_encrypter_uses_config(self): + """Test that global encrypter uses config""" + # Clear the global instance first + import core.tools.utils.system_encryption + + core.tools.utils.system_encryption._encrypter = None + + with patch("core.tools.utils.system_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "global_secret" + encrypter = get_system_encrypter() + + expected_key = hashlib.sha256(b"global_secret").digest() + assert encrypter.key == expected_key + + +class TestConvenienceFunctions: + """Test cases for convenience functions""" + + def test_encrypt_system_params(self): + """Test encrypt_system_params convenience function""" + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_params(params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_decrypt_system_params(self): + """Test decrypt_system_params convenience function""" + params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_params(params) + decrypted = decrypt_system_params(encrypted) + + assert decrypted == params + + def test_convenience_functions_consistency(self): + """Test that convenience functions work consistently""" + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + ] + + for original_params in test_cases: + encrypted = encrypt_system_params(original_params) + decrypted = decrypt_system_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_convenience_functions_with_errors(self): + """Test convenience functions with error conditions""" + # Test encryption with invalid input + with pytest.raises(Exception): # noqa: B017 + encrypt_system_params(None) + + # Test decryption with invalid input + with pytest.raises(ValueError): + decrypt_system_params("") + + with pytest.raises(ValueError): + decrypt_system_params(None) + + +class TestErrorHandling: + """Test cases for error handling""" + + def test_encryption_error_inheritance(self): + """Test that EncryptionError is a proper exception""" + error = EncryptionError("Test error") + assert isinstance(error, Exception) + assert str(error) == "Test error" + + def test_encryption_error_with_cause(self): + """Test EncryptionError with cause""" + original_error = ValueError("Original error") + error = EncryptionError("Wrapper error") + error.__cause__ = original_error + + assert isinstance(error, Exception) + assert str(error) == "Wrapper error" + assert error.__cause__ is original_error + + def test_error_messages_are_informative(self): + """Test that error messages are informative""" + encrypter = SystemEncrypter("test_secret") + + # Test empty string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params("") + assert "encrypted_data cannot be empty" in str(exc_info.value) + + # Test non-string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_params(123) + assert "encrypted_data must be a string" in str(exc_info.value) + + # Test invalid format error + short_data = base64.b64encode(b"short").decode() + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) + assert "Invalid encrypted data format" in str(exc_info.value) + + +class TestEdgeCases: + """Test cases for edge cases and boundary conditions""" + + def test_very_long_secret_key(self): + """Test with very long secret key""" + long_secret = "x" * 10000 + encrypter = SystemEncrypter(long_secret) + + # Key should still be 32 bytes due to SHA-256 + assert len(encrypter.key) == 32 + + # Should still work normally + params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_special_characters_in_secret_key(self): + """Test with special characters in secret key""" + special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀" + encrypter = SystemEncrypter(special_secret) + + params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_empty_values_in_params(self): + """Test with empty values in params""" + params = { + "client_id": "", + "client_secret": "", + "empty_dict": {}, + "empty_list": [], + "empty_string": "", + "zero": 0, + "false": False, + "none": None, + } + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_deeply_nested_params(self): + """Test with deeply nested params""" + params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_params_with_all_json_types(self): + """Test with all JSON-supported data types""" + params = { + "string": "test_string", + "integer": 42, + "float": 3.14159, + "boolean_true": True, + "boolean_false": False, + "null_value": None, + "empty_string": "", + "array": [1, "two", 3.0, True, False, None], + "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, + } + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + +class TestPerformance: + """Test cases for performance considerations""" + + def test_large_params(self): + """Test with large params""" + large_value = "x" * 100000 # 100KB + params = {"client_id": "test_id", "large_data": large_value} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_many_fields_params(self): + """Test with many fields in params""" + params = {f"field_{i}": f"value_{i}" for i in range(1000)} + + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params + + def test_repeated_encryption_decryption(self): + """Test repeated encryption and decryption operations""" + encrypter = SystemEncrypter("test_secret") + params = {"client_id": "test_id", "client_secret": "test_secret"} + + # Test multiple rounds of encryption/decryption + for i in range(100): + encrypted = encrypter.encrypt_params(params) + decrypted = encrypter.decrypt_params(encrypted) + assert decrypted == params diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py deleted file mode 100644 index e2607f0fb1..0000000000 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ /dev/null @@ -1,619 +0,0 @@ -import base64 -import hashlib -from unittest.mock import patch - -import pytest -from Crypto.Cipher import AES -from Crypto.Random import get_random_bytes -from Crypto.Util.Padding import pad - -from core.tools.utils.system_oauth_encryption import ( - OAuthEncryptionError, - SystemOAuthEncrypter, - create_system_oauth_encrypter, - decrypt_system_oauth_params, - encrypt_system_oauth_params, - get_system_oauth_encrypter, -) - - -class TestSystemOAuthEncrypter: - """Test cases for SystemOAuthEncrypter class""" - - def test_init_with_secret_key(self): - """Test initialization with provided secret key""" - secret_key = "test_secret_key" - encrypter = SystemOAuthEncrypter(secret_key=secret_key) - expected_key = hashlib.sha256(secret_key.encode()).digest() - assert encrypter.key == expected_key - - def test_init_with_none_secret_key(self): - """Test initialization with None secret key falls back to config""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = SystemOAuthEncrypter(secret_key=None) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - def test_init_with_empty_secret_key(self): - """Test initialization with empty secret key""" - encrypter = SystemOAuthEncrypter(secret_key="") - expected_key = hashlib.sha256(b"").digest() - assert encrypter.key == expected_key - - def test_init_without_secret_key_uses_config(self): - """Test initialization without secret key uses config""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "default_secret" - encrypter = SystemOAuthEncrypter() - expected_key = hashlib.sha256(b"default_secret").digest() - assert encrypter.key == expected_key - - def test_encrypt_oauth_params_basic(self): - """Test basic OAuth parameters encryption""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - # Should be valid base64 - try: - base64.b64decode(encrypted) - except Exception: - pytest.fail("Encrypted result is not valid base64") - - def test_encrypt_oauth_params_empty_dict(self): - """Test encryption with empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_complex_data(self): - """Test encryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "scopes": ["read", "write", "admin"], - "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, - "numeric_value": 42, - "boolean_value": False, - "null_value": None, - } - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_unicode_data(self): - """Test encryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"} - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_large_data(self): - """Test encryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = { - "client_id": "test_id", - "large_data": "x" * 10000, # 10KB of data - } - - encrypted = encrypter.encrypt_oauth_params(oauth_params) - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_encrypt_oauth_params_invalid_input(self): - """Test encryption with invalid input types""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) - - with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") - - def test_decrypt_oauth_params_basic(self): - """Test basic OAuth parameters decryption""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_empty_dict(self): - """Test decryption of empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = {} - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_complex_data(self): - """Test decryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "scopes": ["read", "write", "admin"], - "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, - "numeric_value": 42, - "boolean_value": False, - "null_value": None, - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_unicode_data(self): - """Test decryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "client_secret": "test_secret", - "description": "This is a test case 🚀", - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_large_data(self): - """Test decryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") - original_params = { - "client_id": "test_id", - "large_data": "x" * 10000, # 10KB of data - } - - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - - assert decrypted == original_params - - def test_decrypt_oauth_params_invalid_base64(self): - """Test decryption with invalid base64 data""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params("invalid_base64!") - - def test_decrypt_oauth_params_empty_string(self): - """Test decryption with empty string""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") - - assert "encrypted_data cannot be empty" in str(exc_info.value) - - def test_decrypt_oauth_params_non_string_input(self): - """Test decryption with non-string input""" - encrypter = SystemOAuthEncrypter("test_secret") - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) - - assert "encrypted_data must be a string" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) - - assert "encrypted_data must be a string" in str(exc_info.value) - - def test_decrypt_oauth_params_too_short_data(self): - """Test decryption with too short encrypted data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create data that's too short (less than 32 bytes) - short_data = base64.b64encode(b"short").decode() - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) - - assert "Invalid encrypted data format" in str(exc_info.value) - - def test_decrypt_oauth_params_corrupted_data(self): - """Test decryption with corrupted data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create corrupted data (valid base64 but invalid encrypted content) - corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(corrupted_data) - - def test_decrypt_oauth_params_wrong_key(self): - """Test decryption with wrong key""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") - - original_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypter1.encrypt_oauth_params(original_params) - - with pytest.raises(OAuthEncryptionError): - encrypter2.decrypt_oauth_params(encrypted) - - def test_encryption_decryption_consistency(self): - """Test that encryption and decryption are consistent""" - encrypter = SystemOAuthEncrypter("test_secret") - - test_cases = [ - {}, - {"simple": "value"}, - {"client_id": "id", "client_secret": "secret"}, - {"complex": {"nested": {"deep": "value"}}}, - {"unicode": "test 🚀"}, - {"numbers": 42, "boolean": True, "null": None}, - {"array": [1, 2, 3, "four", {"five": 5}]}, - ] - - for original_params in test_cases: - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == original_params, f"Failed for case: {original_params}" - - def test_encryption_randomness(self): - """Test that encryption produces different results for same input""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted1 = encrypter.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter.encrypt_oauth_params(oauth_params) - - # Should be different due to random IV - assert encrypted1 != encrypted2 - - # But should decrypt to same result - decrypted1 = encrypter.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter.decrypt_oauth_params(encrypted2) - assert decrypted1 == decrypted2 == oauth_params - - def test_different_secret_keys_produce_different_results(self): - """Test that different secret keys produce different encrypted results""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") - - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted1 = encrypter1.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter2.encrypt_oauth_params(oauth_params) - - # Should produce different encrypted results - assert encrypted1 != encrypted2 - - # But each should decrypt correctly with its own key - decrypted1 = encrypter1.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter2.decrypt_oauth_params(encrypted2) - assert decrypted1 == decrypted2 == oauth_params - - @patch("core.tools.utils.system_oauth_encryption.get_random_bytes") - def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes): - """Test encryption when crypto operation fails""" - mock_get_random_bytes.side_effect = Exception("Crypto error") - - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id"} - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) - - assert "Encryption failed" in str(exc_info.value) - - @patch("core.tools.utils.system_oauth_encryption.TypeAdapter") - def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter): - """Test encryption when JSON serialization fails""" - mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") - - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id"} - - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) - - assert "Encryption failed" in str(exc_info.value) - - def test_decrypt_oauth_params_invalid_json(self): - """Test decryption with invalid JSON data""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Create valid encrypted data but with invalid JSON content - iv = get_random_bytes(16) - cipher = AES.new(encrypter.key, AES.MODE_CBC, iv) - invalid_json = b"invalid json content" - padded_data = pad(invalid_json, AES.block_size) - encrypted_data = cipher.encrypt(padded_data) - combined = iv + encrypted_data - encoded = base64.b64encode(combined).decode() - - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(encoded) - - def test_key_derivation_consistency(self): - """Test that key derivation is consistent""" - secret_key = "test_secret" - encrypter1 = SystemOAuthEncrypter(secret_key) - encrypter2 = SystemOAuthEncrypter(secret_key) - - assert encrypter1.key == encrypter2.key - - # Keys should be 32 bytes (256 bits) - assert len(encrypter1.key) == 32 - - -class TestFactoryFunctions: - """Test cases for factory functions""" - - def test_create_system_oauth_encrypter_with_secret(self): - """Test factory function with secret key""" - secret_key = "test_secret" - encrypter = create_system_oauth_encrypter(secret_key) - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(secret_key.encode()).digest() - assert encrypter.key == expected_key - - def test_create_system_oauth_encrypter_without_secret(self): - """Test factory function without secret key""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter() - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - def test_create_system_oauth_encrypter_with_none_secret(self): - """Test factory function with None secret key""" - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter(None) - - assert isinstance(encrypter, SystemOAuthEncrypter) - expected_key = hashlib.sha256(b"config_secret").digest() - assert encrypter.key == expected_key - - -class TestGlobalEncrypterInstance: - """Test cases for global encrypter instance""" - - def test_get_system_oauth_encrypter_singleton(self): - """Test that get_system_oauth_encrypter returns singleton instance""" - # Clear the global instance first - import core.tools.utils.system_oauth_encryption - - core.tools.utils.system_oauth_encryption._oauth_encrypter = None - - encrypter1 = get_system_oauth_encrypter() - encrypter2 = get_system_oauth_encrypter() - - assert encrypter1 is encrypter2 - assert isinstance(encrypter1, SystemOAuthEncrypter) - - def test_get_system_oauth_encrypter_uses_config(self): - """Test that global encrypter uses config""" - # Clear the global instance first - import core.tools.utils.system_oauth_encryption - - core.tools.utils.system_oauth_encryption._oauth_encrypter = None - - with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: - mock_config.SECRET_KEY = "global_secret" - encrypter = get_system_oauth_encrypter() - - expected_key = hashlib.sha256(b"global_secret").digest() - assert encrypter.key == expected_key - - -class TestConvenienceFunctions: - """Test cases for convenience functions""" - - def test_encrypt_system_oauth_params(self): - """Test encrypt_system_oauth_params convenience function""" - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypt_system_oauth_params(oauth_params) - - assert isinstance(encrypted, str) - assert len(encrypted) > 0 - - def test_decrypt_system_oauth_params(self): - """Test decrypt_system_oauth_params convenience function""" - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - encrypted = encrypt_system_oauth_params(oauth_params) - decrypted = decrypt_system_oauth_params(encrypted) - - assert decrypted == oauth_params - - def test_convenience_functions_consistency(self): - """Test that convenience functions work consistently""" - test_cases = [ - {}, - {"simple": "value"}, - {"client_id": "id", "client_secret": "secret"}, - {"complex": {"nested": {"deep": "value"}}}, - {"unicode": "test 🚀"}, - {"numbers": 42, "boolean": True, "null": None}, - ] - - for original_params in test_cases: - encrypted = encrypt_system_oauth_params(original_params) - decrypted = decrypt_system_oauth_params(encrypted) - assert decrypted == original_params, f"Failed for case: {original_params}" - - def test_convenience_functions_with_errors(self): - """Test convenience functions with error conditions""" - # Test encryption with invalid input - with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) - - # Test decryption with invalid input - with pytest.raises(ValueError): - decrypt_system_oauth_params("") - - with pytest.raises(ValueError): - decrypt_system_oauth_params(None) - - -class TestErrorHandling: - """Test cases for error handling""" - - def test_oauth_encryption_error_inheritance(self): - """Test that OAuthEncryptionError is a proper exception""" - error = OAuthEncryptionError("Test error") - assert isinstance(error, Exception) - assert str(error) == "Test error" - - def test_oauth_encryption_error_with_cause(self): - """Test OAuthEncryptionError with cause""" - original_error = ValueError("Original error") - error = OAuthEncryptionError("Wrapper error") - error.__cause__ = original_error - - assert isinstance(error, Exception) - assert str(error) == "Wrapper error" - assert error.__cause__ is original_error - - def test_error_messages_are_informative(self): - """Test that error messages are informative""" - encrypter = SystemOAuthEncrypter("test_secret") - - # Test empty string error - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") - assert "encrypted_data cannot be empty" in str(exc_info.value) - - # Test non-string error - with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) - assert "encrypted_data must be a string" in str(exc_info.value) - - # Test invalid format error - short_data = base64.b64encode(b"short").decode() - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) - assert "Invalid encrypted data format" in str(exc_info.value) - - -class TestEdgeCases: - """Test cases for edge cases and boundary conditions""" - - def test_very_long_secret_key(self): - """Test with very long secret key""" - long_secret = "x" * 10000 - encrypter = SystemOAuthEncrypter(long_secret) - - # Key should still be 32 bytes due to SHA-256 - assert len(encrypter.key) == 32 - - # Should still work normally - oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_special_characters_in_secret_key(self): - """Test with special characters in secret key""" - special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀" - encrypter = SystemOAuthEncrypter(special_secret) - - oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_empty_values_in_oauth_params(self): - """Test with empty values in oauth params""" - oauth_params = { - "client_id": "", - "client_secret": "", - "empty_dict": {}, - "empty_list": [], - "empty_string": "", - "zero": 0, - "false": False, - "none": None, - } - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_deeply_nested_oauth_params(self): - """Test with deeply nested oauth params""" - oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_oauth_params_with_all_json_types(self): - """Test with all JSON-supported data types""" - oauth_params = { - "string": "test_string", - "integer": 42, - "float": 3.14159, - "boolean_true": True, - "boolean_false": False, - "null_value": None, - "empty_string": "", - "array": [1, "two", 3.0, True, False, None], - "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, - } - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - -class TestPerformance: - """Test cases for performance considerations""" - - def test_large_oauth_params(self): - """Test with large oauth params""" - large_value = "x" * 100000 # 100KB - oauth_params = {"client_id": "test_id", "large_data": large_value} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_many_fields_oauth_params(self): - """Test with many fields in oauth params""" - oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)} - - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params - - def test_repeated_encryption_decryption(self): - """Test repeated encryption and decryption operations""" - encrypter = SystemOAuthEncrypter("test_secret") - oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - - # Test multiple rounds of encryption/decryption - for i in range(100): - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) - assert decrypted == oauth_params diff --git a/docker/.env.example b/docker/.env.example index ec7d572057..29741474fa 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1467,6 +1467,11 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + FORCE_VERIFYING_SIGNATURE=true ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES=true diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aaf099453a..60ba510f44 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -629,6 +629,9 @@ x-shared-env: &shared-api-worker-env ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} + CREATORS_PLATFORM_FEATURES_ENABLED: ${CREATORS_PLATFORM_FEATURES_ENABLED:-true} + CREATORS_PLATFORM_API_URL: ${CREATORS_PLATFORM_API_URL:-https://creators.dify.ai} + CREATORS_PLATFORM_OAUTH_CLIENT_ID: ${CREATORS_PLATFORM_OAUTH_CLIENT_ID:-} FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES: ${ENFORCE_LANGGENIUS_PLUGIN_SIGNATURES:-true} PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index dd95dc04ba..55666db193 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -16,9 +16,9 @@ import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' +import { setOAuthPendingRedirect } from '@/app/signin/utils/post-login-redirect' import { useRouter, useSearchParams } from '@/next/navigation' -import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common' +import { isLegacyBase401, useLogout, userProfileQueryOptions } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' function buildReturnUrl(pathname: string, search: string) { @@ -73,14 +73,17 @@ export default function OAuthAuthorize() { const userProfile = userProfileResp?.profile const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() + const { mutateAsync: logout } = useLogout() const hasNotifiedRef = useRef(false) const isLoading = isOAuthLoading || isProfileLoading - const onLoginSwitchClick = () => { + const onLoginSwitchClick = async () => { try { - const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) - setPostLoginRedirect(returnUrl) - router.push('/signin') + const returnUrl = buildReturnUrl('/account/oauth/authorize', `?${searchParams.toString()}`) + setOAuthPendingRedirect(returnUrl) + if (isLoggedIn) + await logout() + router.push(`/signin?redirect_url=${encodeURIComponent(returnUrl)}`) } catch { router.push('/signin') diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index 2c50312590..3d2af1ce61 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -85,7 +85,7 @@ export const AppInitializer = ({ return } - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) if (redirectUrl) { location.replace(redirectUrl) return diff --git a/web/app/components/app/app-publisher/__tests__/index.spec.tsx b/web/app/components/app/app-publisher/__tests__/index.spec.tsx index aa9cda8e34..5df331767b 100644 --- a/web/app/components/app/app-publisher/__tests__/index.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/index.spec.tsx @@ -80,8 +80,11 @@ vi.mock('@/service/explore', () => ({ fetchInstalledAppList: (...args: unknown[]) => mockFetchInstalledAppList(...args), })) +const mockPublishToCreatorsPlatform = vi.fn() + vi.mock('@/service/apps', () => ({ fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args), + publishToCreatorsPlatform: (...args: unknown[]) => mockPublishToCreatorsPlatform(...args), })) vi.mock('@/service/use-workflow', () => ({ @@ -434,6 +437,76 @@ describe('AppPublisher', () => { }) }) + it('should show marketplace button and open redirect URL on success', async () => { + mockPublishToCreatorsPlatform.mockResolvedValue({ redirect_url: 'https://marketplace.example.com/publish?code=abc' }) + const windowOpenSpy = vi.spyOn(window, 'open').mockImplementation(() => null) + + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('common.publishToMarketplace')) + + await waitFor(() => { + expect(mockPublishToCreatorsPlatform).toHaveBeenCalledWith({ appID: 'app-1' }) + expect(windowOpenSpy).toHaveBeenCalledWith('https://marketplace.example.com/publish?code=abc', '_blank') + }) + + windowOpenSpy.mockRestore() + }) + + it('should show toast error when publish to marketplace fails', async () => { + mockPublishToCreatorsPlatform.mockRejectedValue(new Error('network error')) + + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('common.publishToMarketplace')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.publishToMarketplaceFailed') + }) + }) + + it('should disable marketplace button when not yet published', () => { + renderWithSystemFeatures( + , + { systemFeatures: { webapp_auth: { enabled: true }, enable_creators_platform: true } }, + ) + + fireEvent.click(screen.getByText('common.publish')) + const marketplaceButton = screen.getByText('common.publishToMarketplace').closest('a, button, div[role="button"]') as HTMLElement + expect(marketplaceButton).toBeInTheDocument() + // clicking should not call the API because publishedAt is undefined + fireEvent.click(screen.getByText('common.publishToMarketplace')) + expect(mockPublishToCreatorsPlatform).not.toHaveBeenCalled() + }) + + it('should hide marketplace button when enable_creators_platform is false', () => { + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + expect(screen.queryByText('common.publishToMarketplace')).not.toBeInTheDocument() + }) + it('should keep access control open when app detail is unavailable during confirmation', async () => { mockAppDetail = null diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index b85e888557..fe6fe5806f 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -5,6 +5,7 @@ import type { PublishWorkflowParams } from '@/types/workflow' import { Button } from '@langgenius/dify-ui/button' import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover' import { toast } from '@langgenius/dify-ui/toast' +import { RiStoreLine } from '@remixicon/react' import { useSuspenseQuery } from '@tanstack/react-query' import { useKeyPress } from 'ahooks' import { @@ -26,7 +27,7 @@ import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' -import { fetchAppDetailDirect } from '@/service/apps' +import { fetchAppDetailDirect, publishToCreatorsPlatform } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useInvalidateAppWorkflow } from '@/service/use-workflow' @@ -40,6 +41,7 @@ import { PublisherActionsSection, PublisherSummarySection, } from './sections' +import SuggestedAction from './suggested-action' import { getDisabledFunctionTooltip, getPublisherAppUrl, @@ -100,6 +102,7 @@ const AppPublisher = ({ const [showAppAccessControl, setShowAppAccessControl] = useState(false) const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) + const [publishingToMarketplace, setPublishingToMarketplace] = useState(false) const workflowStore = useContext(WorkflowContext) const appDetail = useAppStore(state => state.appDetail) @@ -219,6 +222,23 @@ const AppPublisher = ({ } }, [appDetail, setAppDetail]) + const handlePublishToMarketplace = useCallback(async () => { + if (!appDetail?.id || publishingToMarketplace) + return + setPublishingToMarketplace(true) + try { + const res = await publishToCreatorsPlatform({ appID: appDetail.id }) + if (res.redirect_url) + window.open(res.redirect_url, '_blank') + } + catch { + toast.error(t('common.publishToMarketplaceFailed', { ns: 'workflow' })) + } + finally { + setPublishingToMarketplace(false) + } + }, [appDetail?.id, publishingToMarketplace, t]) + useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (e) => { e.preventDefault() if (publishDisabled || published) @@ -336,6 +356,19 @@ const AppPublisher = ({ workflowToolAvailable={workflowToolAvailable} workflowToolMessage={workflowToolMessage} /> + {systemFeatures.enable_creators_platform && ( +
+ } + disabled={!publishedAt || publishingToMarketplace} + onClick={handlePublishToMarketplace} + > + {publishingToMarketplace + ? t('common.publishingToMarketplace', { ns: 'workflow' }) + : t('common.publishToMarketplace', { ns: 'workflow' })} + +
+ )} { />, ) - expect(screen.getByText('importFromDSL'))!.toBeInTheDocument() + expect(screen.getByText('importApp'))!.toBeInTheDocument() await waitFor(() => { expect(screen.getByText('demo.yml'))!.toBeInTheDocument() @@ -161,7 +161,7 @@ describe('CreateFromDSLModal', () => { }) expect(screen.getByPlaceholderText('importFromDSLUrlPlaceholder'))!.toBeInTheDocument() - const closeTrigger = screen.getByText('importFromDSL').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement + const closeTrigger = screen.getByText('importApp').parentElement?.querySelector('.cursor-pointer.items-center') as HTMLElement fireEvent.click(closeTrigger) expect(handleClose).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 4f99fe9027..bc5f352634 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -225,7 +225,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS onClose={noop} >
- {t('importFromDSL', { ns: 'app' })} + {t('importApp', { ns: 'app' })}
onClose()} diff --git a/web/app/components/apps/__tests__/index.spec.tsx b/web/app/components/apps/__tests__/index.spec.tsx index 2e0d1bcc84..94fa9f3484 100644 --- a/web/app/components/apps/__tests__/index.spec.tsx +++ b/web/app/components/apps/__tests__/index.spec.tsx @@ -7,9 +7,21 @@ import { useContextSelector } from 'use-context-selector' import AppListContext from '@/context/app-list-context' import { fetchAppDetail } from '@/service/explore' import { AppModeEnum } from '@/types/app' - import Apps from '../index' +vi.mock('@/next/dynamic', () => ({ + default: (loader: () => Promise<{ default: React.ComponentType }>) => { + const LazyComp = React.lazy(loader) + return function DynamicWrapper(props: Record) { + return React.createElement( + React.Suspense, + { fallback: null }, + React.createElement(LazyComp, props), + ) + } + }, +})) + let documentTitleCalls: string[] = [] let educationInitCalls: number = 0 const mockHandleImportDSL = vi.fn() @@ -65,6 +77,16 @@ vi.mock('@/hooks/use-import-dsl', () => ({ }), })) +const mockReplace = vi.fn() +let mockSearchParams = new URLSearchParams() + +vi.mock('@/next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), + useSearchParams: () => mockSearchParams, +})) + vi.mock('../list', () => { const MockList = () => { const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel) @@ -129,6 +151,16 @@ vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({ ), })) +vi.mock('../import-from-marketplace-template-modal', () => ({ + default: ({ templateId, onClose, onConfirm }: { templateId: string, onClose: () => void, onConfirm: (dsl: string) => void }) => ( +
+ {templateId} + + +
+ ), +})) + vi.mock('@/service/explore', () => ({ fetchAppDetail: vi.fn(), })) @@ -161,6 +193,8 @@ describe('Apps', () => { vi.clearAllMocks() documentTitleCalls = [] educationInitCalls = 0 + mockSearchParams = new URLSearchParams() + mockReplace.mockClear() mockFetchAppDetail.mockResolvedValue({ id: 'template-1', name: 'Sample App', @@ -304,6 +338,66 @@ describe('Apps', () => { }) }) + describe('Marketplace Template', () => { + it('should render the template modal when template-id is in search params', async () => { + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + expect(await screen.findByTestId('marketplace-template-modal')).toBeInTheDocument() + expect(screen.getByTestId('template-id')).toHaveTextContent('tpl-42') + }) + + it('should not render the template modal when no template-id is present', () => { + renderWithClient() + + expect(screen.queryByTestId('marketplace-template-modal')).not.toBeInTheDocument() + }) + + it('should close the template modal and remove template-id from URL', async () => { + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('close-template')) + + expect(mockReplace).toHaveBeenCalledTimes(1) + const replaceArg = mockReplace.mock.calls[0]![0] as string + expect(replaceArg).not.toContain('template-id') + }) + + it('should import DSL from marketplace template on confirm', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('confirm-template')) + + await waitFor(() => { + expect(mockHandleImportDSL).toHaveBeenCalledWith( + { mode: 'yaml-content', yaml_content: 'yaml-dsl-content' }, + expect.objectContaining({ onSuccess: expect.any(Function) }), + ) + expect(mockReplace).toHaveBeenCalled() + }) + }) + + it('should show DSL confirm modal when marketplace import is pending', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => { + options.onPending?.() + }) + mockSearchParams = new URLSearchParams('template-id=tpl-42') + renderWithClient() + + fireEvent.click(await screen.findByTestId('confirm-template')) + + await waitFor(() => { + expect(screen.getByTestId('dsl-confirm-modal')).toBeInTheDocument() + expect(mockReplace).toHaveBeenCalled() + }) + }) + }) + describe('Styling', () => { it('should have overflow-y-auto class', () => { const { container } = renderWithClient() diff --git a/web/app/components/apps/import-from-marketplace-template-modal.tsx b/web/app/components/apps/import-from-marketplace-template-modal.tsx new file mode 100644 index 0000000000..a6a3dee8e4 --- /dev/null +++ b/web/app/components/apps/import-from-marketplace-template-modal.tsx @@ -0,0 +1,182 @@ +'use client' + +import { Button } from '@langgenius/dify-ui/button' +import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog' +import { toast } from '@langgenius/dify-ui/toast' +import { RiCloseLine } from '@remixicon/react' +import { useCallback, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { MARKETPLACE_API_PREFIX } from '@/config' +import { + fetchMarketplaceTemplateDSL, + useMarketplaceTemplateDetail, +} from '@/service/marketplace-templates' + +type ImportFromMarketplaceTemplateModalProps = { + templateId: string + onClose: () => void + onConfirm: (dslContent: string) => void +} + +const ImportFromMarketplaceTemplateModal = ({ + templateId, + onClose, + onConfirm, +}: ImportFromMarketplaceTemplateModalProps) => { + const { t } = useTranslation() + const { data, isLoading, isError } = useMarketplaceTemplateDetail(templateId) + const template = data?.data + const [importing, setImporting] = useState(false) + const isImportingRef = useRef(false) + + const CATEGORY_I18N_MAP: Record = useMemo(() => ({ + marketing: t('marketplace.template.category.marketing', { ns: 'app' }), + sales: t('marketplace.template.category.sales', { ns: 'app' }), + support: t('marketplace.template.category.support', { ns: 'app' }), + operations: t('marketplace.template.category.operations', { ns: 'app' }), + it: t('marketplace.template.category.it', { ns: 'app' }), + knowledge: t('marketplace.template.category.knowledge', { ns: 'app' }), + design: t('marketplace.template.category.design', { ns: 'app' }), + }), [t]) + + const translateCategory = useCallback((slug: string) => { + return CATEGORY_I18N_MAP[slug] ?? slug + }, [CATEGORY_I18N_MAP]) + + const handleConfirm = useCallback(async () => { + if (isImportingRef.current) + return + isImportingRef.current = true + setImporting(true) + try { + const dsl = await fetchMarketplaceTemplateDSL(templateId) + onConfirm(dsl) + } + catch { + toast.error(t('marketplace.template.importFailed', { ns: 'app' })) + } + finally { + setImporting(false) + isImportingRef.current = false + } + }, [templateId, onConfirm, t]) + + return ( + { + if (!open) + onClose() + }} + > + +
+ {t('marketplace.template.modalTitle', { ns: 'app' })} +
+ +
+
+ +
+ {isLoading && ( +
+
Loading...
+
+ )} + + {isError && ( +
+
+ {t('marketplace.template.fetchFailed', { ns: 'app' })} +
+
+ )} + + {template && ( +
+
+ {template.icon_file_key + ? ( + {template.template_name} + ) + : ( +
+ {template.icon || '📄'} +
+ )} +
+
{template.template_name}
+
+ + {t('marketplace.template.publishedBy', { ns: 'app' })} + {' '} + {template.publisher_unique_handle} + + · + + {t('marketplace.template.usageCount', { ns: 'app' })} + {' '} + {template.usage_count} + +
+
+
+ + {template.overview && ( +
+
+ {t('marketplace.template.overview', { ns: 'app' })} +
+
+ {template.overview} +
+
+ )} + + {template.categories.length > 0 && ( +
+ {template.categories.map(cat => ( + + {translateCategory(cat)} + + ))} +
+ )} +
+ )} +
+ +
+ + +
+
+
+ ) +} + +export default ImportFromMarketplaceTemplateModal diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index 9bf07e81e6..9d74968605 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -9,6 +9,7 @@ import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' import dynamic from '@/next/dynamic' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAppDetail } from '@/service/explore' import { trackCreateApp } from '@/utils/create-app-tracking' import List from './list' @@ -16,9 +17,14 @@ import List from './list' const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) +const ImportFromMarketplaceTemplateModal = dynamic(() => import('./import-from-marketplace-template-modal'), { ssr: false }) const Apps = () => { const { t } = useTranslation() + const searchParams = useSearchParams() + const { replace } = useRouter() + const templateId = searchParams.get('template-id') + const templateDismissedRef = useRef(false) useDocumentTitle(t('menus.apps', { ns: 'common' })) useEducationInit() @@ -58,6 +64,14 @@ const Apps = () => { const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) + const handleCloseTemplateModal = useCallback(() => { + templateDismissedRef.current = true + const params = new URLSearchParams(searchParams.toString()) + params.delete('template-id') + const query = params.toString() + replace(query ? `?${query}` : window.location.pathname, { scroll: false }) + }, [searchParams, replace]) + const { handleImportDSL, handleImportDSLConfirm, @@ -74,6 +88,22 @@ const Apps = () => { }) }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) + const handleMarketplaceTemplateConfirm = useCallback(async (dslContent: string) => { + await handleImportDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: dslContent, + }, { + onSuccess: () => { + handleCloseTemplateModal() + onSuccess() + }, + onPending: () => { + handleCloseTemplateModal() + setShowDSLConfirmModal(true) + }, + }) + }, [handleImportDSL, handleCloseTemplateModal, onSuccess]) + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, @@ -152,6 +182,14 @@ const Apps = () => { onHide={() => setIsShowCreateModal(false)} /> )} + + {templateId && !templateDismissedRef.current && ( + + )}
) diff --git a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx index 7a02781c17..08ac245172 100644 --- a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx +++ b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx @@ -156,7 +156,7 @@ describe('PanelContextmenu', () => { fireEvent.click(screen.getByText('common.run')) fireEvent.click(screen.getByText('common.pasteHere')) fireEvent.click(screen.getByText('export')) - fireEvent.click(screen.getByText('common.importDSL')) + fireEvent.click(screen.getByText('importApp')) clickAwayHandler?.() expect(mockHandleAddNote).toHaveBeenCalledTimes(1) diff --git a/web/app/components/workflow/panel-contextmenu.tsx b/web/app/components/workflow/panel-contextmenu.tsx index ffe88d3dc9..4478839077 100644 --- a/web/app/components/workflow/panel-contextmenu.tsx +++ b/web/app/components/workflow/panel-contextmenu.tsx @@ -137,7 +137,7 @@ const PanelContextmenu = () => { className="flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover" onClick={() => setShowImportDSLModal(true)} > - {t('common.importDSL', { ns: 'workflow' })} + {t('importApp', { ns: 'app' })}
diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx index cfa9c995eb..549dee487f 100644 --- a/web/app/components/workflow/update-dsl-modal.tsx +++ b/web/app/components/workflow/update-dsl-modal.tsx @@ -205,7 +205,7 @@ const UpdateDSLModal = ({ onClose={onCancel} >
-
{t('common.importDSL', { ns: 'workflow' })}
+
{t('importApp', { ns: 'app' })}
diff --git a/web/app/page.tsx b/web/app/page.tsx index 65f8827e01..a866fd4c39 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -1,18 +1,23 @@ -import Loading from '@/app/components/base/loading' -import Link from '@/next/link' +import { redirect } from '@/next/navigation' -const Home = async () => { - return ( -
+type HomePageProps = { + searchParams: Promise> +} -
- -
- 🚀 -
-
-
- ) +const Home = async ({ searchParams }: HomePageProps) => { + const resolvedSearchParams = await searchParams + const urlSearchParams = new URLSearchParams() + Object.entries(resolvedSearchParams).forEach(([key, value]) => { + if (value === undefined) + return + if (Array.isArray(value)) { + value.forEach(item => urlSearchParams.append(key, item)) + return + } + urlSearchParams.set(key, value) + }) + const queryString = urlSearchParams.toString() + redirect(queryString ? `/apps?${queryString}` : '/apps') } export default Home diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index fb52e0b5b7..42024c561b 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -51,7 +51,7 @@ export default function CheckCode() { router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 6feaf11426..30bc78666c 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -75,7 +75,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index 7066ab041c..43ca96ab05 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -65,7 +65,7 @@ export default function InviteSettingsPage() { if (res.result === 'success') { // Tokens are now stored in cookies by the backend await setLocaleOnClient(language!, false) - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } } diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index 779aba5c9c..a32c7e9b3d 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -49,7 +49,7 @@ const NormalForm = () => { try { if (isLoggedIn) { setIsRedirecting(true) - const redirectUrl = resolvePostLoginRedirect() + const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') return } diff --git a/web/app/signin/utils/post-login-redirect.ts b/web/app/signin/utils/post-login-redirect.ts index a94fb2ad79..0015296a41 100644 --- a/web/app/signin/utils/post-login-redirect.ts +++ b/web/app/signin/utils/post-login-redirect.ts @@ -1,15 +1,63 @@ -let postLoginRedirect: string | null = null +import type { ReadonlyURLSearchParams } from '@/next/navigation' -export const setPostLoginRedirect = (value: string | null) => { - postLoginRedirect = value +const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending_redirect' +const REDIRECT_URL_KEY = 'redirect_url' + +type OAuthPendingRedirect = { + value?: string + expiry?: number } -export const resolvePostLoginRedirect = () => { - if (postLoginRedirect) { - const redirectUrl = postLoginRedirect - postLoginRedirect = null - return redirectUrl +const getCurrentUnixTimestamp = () => Math.floor(Date.now() / 1000) + +function removeOAuthPendingRedirect() { + try { + localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY) } - - return null + catch {} +} + +function getOAuthPendingRedirect(): string | null { + try { + const raw = localStorage.getItem(OAUTH_AUTHORIZE_PENDING_KEY) + if (!raw) + return null + removeOAuthPendingRedirect() + const item: OAuthPendingRedirect = JSON.parse(raw) + if (!item.value || typeof item.expiry !== 'number') + return null + return getCurrentUnixTimestamp() > item.expiry ? null : item.value + } + catch { + removeOAuthPendingRedirect() + return null + } +} + +export function setOAuthPendingRedirect(url: string, ttlSeconds: number = 300) { + try { + const item: OAuthPendingRedirect = { + value: url, + expiry: getCurrentUnixTimestamp() + ttlSeconds, + } + localStorage.setItem(OAUTH_AUTHORIZE_PENDING_KEY, JSON.stringify(item)) + } + catch {} +} + +export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) => { + if (searchParams) { + const redirectUrl = searchParams.get(REDIRECT_URL_KEY) + if (redirectUrl) { + try { + removeOAuthPendingRedirect() + return decodeURIComponent(redirectUrl) + } + catch { + removeOAuthPendingRedirect() + return redirectUrl + } + } + } + return getOAuthPendingRedirect() } diff --git a/web/contract/marketplace.ts b/web/contract/marketplace.ts index 3573ba5c24..9f2475041e 100644 --- a/web/contract/marketplace.ts +++ b/web/contract/marketplace.ts @@ -1,5 +1,6 @@ import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types' import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' +import type { MarketplaceTemplate } from '@/types/marketplace-template' import { type } from '@orpc/contract' import { base } from './base' @@ -54,3 +55,15 @@ export const searchAdvancedContract = base body: Omit }>()) .output(type<{ data: PluginsFromMarketplaceResponse }>()) + +export const templateDetailContract = base + .route({ + path: '/templates/{templateId}', + method: 'GET', + }) + .input(type<{ + params: { + templateId: string + } + }>()) + .output(type<{ data: MarketplaceTemplate }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index 45f514a820..086b94f248 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -42,12 +42,13 @@ import { workflowDraftUpdateFeaturesContract, } from './console/workflow' import { workflowCommentContracts } from './console/workflow-comment' -import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace' +import { collectionPluginsContract, collectionsContract, searchAdvancedContract, templateDetailContract } from './marketplace' export const marketplaceRouterContract = { collections: collectionsContract, collectionPlugins: collectionPluginsContract, searchAdvanced: searchAdvancedContract, + templateDetail: templateDetailContract, } export type MarketPlaceInputs = InferContractRouterInputs diff --git a/web/i18n/en-US/app.json b/web/i18n/en-US/app.json index 0ad608d53c..0efa33de07 100644 --- a/web/i18n/en-US/app.json +++ b/web/i18n/en-US/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "Emoji", "iconPicker.image": "Image", "iconPicker.ok": "OK", + "importApp": "Import App", "importDSL": "Import DSL file", "importFromDSL": "Import from DSL", "importFromDSLFile": "From DSL file", "importFromDSLUrl": "From URL", "importFromDSLUrlPlaceholder": "Paste DSL link here", "join": "Join the community", + "marketplace.template.categories": "Categories", + "marketplace.template.category.design": "Design", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "Knowledge", + "marketplace.template.category.marketing": "Marketing", + "marketplace.template.category.operations": "Operations", + "marketplace.template.category.sales": "Sales", + "marketplace.template.category.support": "Support", + "marketplace.template.fetchFailed": "Failed to fetch template", + "marketplace.template.importConfirm": "Import", + "marketplace.template.importFailed": "Failed to import template", + "marketplace.template.modalTitle": "Import from Marketplace", + "marketplace.template.overview": "Overview", + "marketplace.template.publishedBy": "By", + "marketplace.template.usageCount": "Usage", + "marketplace.template.viewOnMarketplace": "View on Marketplace", "maxActiveRequests": "Max concurrent requests", "maxActiveRequestsPlaceholder": "Enter 0 for unlimited", "maxActiveRequestsTip": "Maximum number of concurrent active requests per app (0 for unlimited)", diff --git a/web/i18n/en-US/workflow.json b/web/i18n/en-US/workflow.json index 23516274a9..3bb285d501 100644 --- a/web/i18n/en-US/workflow.json +++ b/web/i18n/en-US/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "Enter content in the box below to start debugging the Chatbot", "common.processData": "Process Data", "common.publish": "Publish", + "common.publishToMarketplace": "Publish to Marketplace", + "common.publishToMarketplaceFailed": "Failed to publish to Marketplace", "common.publishUpdate": "Publish Update", "common.published": "Published", "common.publishedAt": "Published", + "common.publishingToMarketplace": "Publishing...", "common.redo": "Redo", "common.restart": "Restart", "common.restore": "Restore", diff --git a/web/i18n/zh-Hans/app.json b/web/i18n/zh-Hans/app.json index 278a1b782d..8f46a4433e 100644 --- a/web/i18n/zh-Hans/app.json +++ b/web/i18n/zh-Hans/app.json @@ -118,12 +118,29 @@ "iconPicker.emoji": "表情符号", "iconPicker.image": "图片", "iconPicker.ok": "确认", + "importApp": "导入应用", "importDSL": "导入 DSL 文件", "importFromDSL": "导入 DSL", "importFromDSLFile": "文件", "importFromDSLUrl": "URL", "importFromDSLUrlPlaceholder": "输入 DSL 文件的 URL", "join": "参与社区", + "marketplace.template.categories": "分类", + "marketplace.template.category.design": "设计", + "marketplace.template.category.it": "IT", + "marketplace.template.category.knowledge": "知识", + "marketplace.template.category.marketing": "营销", + "marketplace.template.category.operations": "运营", + "marketplace.template.category.sales": "销售", + "marketplace.template.category.support": "支持", + "marketplace.template.fetchFailed": "获取模板失败", + "marketplace.template.importConfirm": "导入", + "marketplace.template.importFailed": "导入模板失败", + "marketplace.template.modalTitle": "从市场导入", + "marketplace.template.overview": "概述", + "marketplace.template.publishedBy": "来自", + "marketplace.template.usageCount": "使用次数", + "marketplace.template.viewOnMarketplace": "在市场查看", "maxActiveRequests": "最大活跃请求数", "maxActiveRequestsPlaceholder": "0 表示不限制", "maxActiveRequestsTip": "当前应用的最大活跃请求数(0 表示不限制)", diff --git a/web/i18n/zh-Hans/workflow.json b/web/i18n/zh-Hans/workflow.json index 593a88f4db..ac3a27af11 100644 --- a/web/i18n/zh-Hans/workflow.json +++ b/web/i18n/zh-Hans/workflow.json @@ -229,9 +229,12 @@ "common.previewPlaceholder": "在下面的框中输入内容开始调试聊天机器人", "common.processData": "数据处理", "common.publish": "发布", + "common.publishToMarketplace": "发布到市场", + "common.publishToMarketplaceFailed": "发布到市场失败", "common.publishUpdate": "发布更新", "common.published": "已发布", "common.publishedAt": "发布于", + "common.publishingToMarketplace": "发布中...", "common.redo": "重做", "common.restart": "重新开始", "common.restore": "恢复", diff --git a/web/next/navigation.ts b/web/next/navigation.ts index ec7c112645..f8ff821d1f 100644 --- a/web/next/navigation.ts +++ b/web/next/navigation.ts @@ -1,4 +1,5 @@ export { + redirect, useParams, usePathname, useRouter, @@ -6,3 +7,4 @@ export { useSelectedLayoutSegment, useSelectedLayoutSegments, } from 'next/navigation' +export type { ReadonlyURLSearchParams } from 'next/navigation' diff --git a/web/service/__tests__/base.spec.ts b/web/service/__tests__/base.spec.ts new file mode 100644 index 0000000000..a4d1dcbfe7 --- /dev/null +++ b/web/service/__tests__/base.spec.ts @@ -0,0 +1,68 @@ +import { buildSigninUrlWithRedirect } from '../base' + +vi.mock('@/utils/var', () => ({ + basePath: '/app', + API_PREFIX: '/console/api', + PUBLIC_API_PREFIX: '/api', + IS_CE_EDITION: false, +})) + +describe('buildSigninUrlWithRedirect', () => { + const originalLocation = globalThis.location + + beforeEach(() => { + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/apps', + href: 'https://example.com/apps', + }, + writable: true, + configurable: true, + }) + }) + + afterEach(() => { + Object.defineProperty(globalThis, 'location', { + value: originalLocation, + writable: true, + configurable: true, + }) + }) + + it('should return plain signin URL for non-OAuth pages', () => { + const url = buildSigninUrlWithRedirect() + expect(url).toBe('https://example.com/app/signin') + }) + + it('should append redirect_url for OAuth authorize pages', () => { + const oauthHref = 'https://example.com/account/oauth/authorize?client_id=abc&state=xyz' + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/account/oauth/authorize', + href: oauthHref, + }, + writable: true, + configurable: true, + }) + + const url = buildSigninUrlWithRedirect() + expect(url).toBe(`https://example.com/app/signin?redirect_url=${encodeURIComponent(oauthHref)}`) + }) + + it('should not include redirect_url for other paths containing partial match', () => { + Object.defineProperty(globalThis, 'location', { + value: { + origin: 'https://example.com', + pathname: '/settings/oauth', + href: 'https://example.com/settings/oauth', + }, + writable: true, + configurable: true, + }) + + const url = buildSigninUrlWithRedirect() + expect(url).toBe('https://example.com/app/signin') + }) +}) diff --git a/web/service/apps.ts b/web/service/apps.ts index b6a5386fe0..d2c6593a34 100644 --- a/web/service/apps.ts +++ b/web/service/apps.ts @@ -192,3 +192,11 @@ export const updateTracingConfig = ({ appId, body }: { appId: string, body: Trac export const removeTracingConfig = ({ appId, provider }: { appId: string, provider: TracingProvider }): Promise => { return del(`/apps/${appId}/trace-config?tracing_provider=${provider}`) } + +type PublishToCreatorsPlatformResponse = { + redirect_url: string +} + +export const publishToCreatorsPlatform = ({ appID }: { appID: string }): Promise => { + return post(`apps/${appID}/publish-to-creators-platform`, { body: {} }) +} diff --git a/web/service/base.ts b/web/service/base.ts index 64d13ef59a..d1ef06c314 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -140,6 +140,20 @@ function jumpTo(url: string) { globalThis.location.href = url } +const OAUTH_AUTHORIZE_PATH = '/account/oauth/authorize' + +export const buildSigninUrlWithRedirect = (): string => { + const loginUrl = `${globalThis.location.origin}${basePath}/signin` + + // Only preserve redirect URL for OAuth authorize pages + if (globalThis.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) { + const currentUrl = globalThis.location.href + return `${loginUrl}?redirect_url=${encodeURIComponent(currentUrl)}` + } + + return loginUrl +} + function unicodeToChar(text: string) { if (!text) return '' @@ -795,14 +809,14 @@ export const request = async(url: string, options = {}, otherOptions?: IOther if (refreshErr === null) return baseFetch(url, options, otherOptionsForBaseFetch) if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) { - jumpTo(loginUrl) + jumpTo(buildSigninUrlWithRedirect()) return Promise.reject(err) } if (!silent) { toast.error(message) return Promise.reject(err) } - jumpTo(loginUrl) + jumpTo(buildSigninUrlWithRedirect()) return Promise.reject(err) } else { diff --git a/web/service/marketplace-templates.ts b/web/service/marketplace-templates.ts new file mode 100644 index 0000000000..d9ff7f314f --- /dev/null +++ b/web/service/marketplace-templates.ts @@ -0,0 +1,18 @@ +import { useQuery } from '@tanstack/react-query' +import { MARKETPLACE_API_PREFIX } from '@/config' +import { marketplaceQuery } from './client' + +export const useMarketplaceTemplateDetail = (templateId: string | null) => { + return useQuery({ + ...marketplaceQuery.templateDetail.queryOptions({ input: { params: { templateId: templateId ?? '' } } }), + enabled: !!templateId, + }) +} + +export const fetchMarketplaceTemplateDSL = async (templateId: string): Promise => { + const url = `${MARKETPLACE_API_PREFIX}/templates/${templateId}/dsl` + const response = await fetch(url) + if (!response.ok) + throw new Error(`Failed to fetch DSL: ${response.statusText}`) + return response.text() +} diff --git a/web/types/feature.ts b/web/types/feature.ts index 635221f2be..77d4045318 100644 --- a/web/types/feature.ts +++ b/web/types/feature.ts @@ -64,6 +64,7 @@ export type SystemFeatures = { allow_email_code_login: boolean allow_email_password_login: boolean } + enable_creators_platform: boolean enable_trial_app: boolean enable_explore_banner: boolean } @@ -108,6 +109,7 @@ export const defaultSystemFeatures: SystemFeatures = { allow_email_code_login: false, allow_email_password_login: false, }, + enable_creators_platform: false, enable_trial_app: false, enable_explore_banner: false, } diff --git a/web/types/marketplace-template.ts b/web/types/marketplace-template.ts new file mode 100644 index 0000000000..ac2b7cb2aa --- /dev/null +++ b/web/types/marketplace-template.ts @@ -0,0 +1,11 @@ +export type MarketplaceTemplate = { + id: string + template_name: string + overview: string + icon: string + icon_background: string + icon_file_key: string + publisher_unique_handle: string + usage_count: number + categories: string[] +}