mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 18:57:19 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -27,7 +27,7 @@ from services.feature_service import BrandingModel
|
||||
class MockEmailRenderer:
|
||||
"""Mock implementation of EmailRenderer protocol"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
def render_template(self, template_path: str, **context: Any) -> str:
|
||||
@ -39,7 +39,7 @@ class MockEmailRenderer:
|
||||
class MockBrandingService:
|
||||
"""Mock implementation of BrandingService protocol"""
|
||||
|
||||
def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None:
|
||||
def __init__(self, enabled: bool = False, application_title: str = "Dify"):
|
||||
self.enabled = enabled
|
||||
self.application_title = application_title
|
||||
|
||||
@ -54,10 +54,10 @@ class MockBrandingService:
|
||||
class MockEmailSender:
|
||||
"""Mock implementation of EmailSender protocol"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.sent_emails: list[dict[str, str]] = []
|
||||
|
||||
def send_email(self, to: str, subject: str, html_content: str) -> None:
|
||||
def send_email(self, to: str, subject: str, html_content: str):
|
||||
"""Mock send_email that records sent emails"""
|
||||
self.sent_emails.append(
|
||||
{
|
||||
@ -134,7 +134,7 @@ class TestEmailI18nService:
|
||||
email_service: EmailI18nService,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with English language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -162,7 +162,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with Chinese language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -181,7 +181,7 @@ class TestEmailI18nService:
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending email with branding enabled"""
|
||||
# Create branding service with branding enabled
|
||||
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
|
||||
@ -215,7 +215,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test language fallback to English when requested language not available"""
|
||||
# Request invite member in Chinese (not configured)
|
||||
email_service.send_email(
|
||||
@ -233,7 +233,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test unknown language code falls back to English"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
@ -246,13 +246,50 @@ class TestEmailI18nService:
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Reset Your Dify Password"
|
||||
|
||||
def test_subject_format_keyerror_fallback_path(
|
||||
self,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Trigger subject KeyError and cover except branch."""
|
||||
# Config with subject that references an unknown key (no {application_title} to avoid second format)
|
||||
config = EmailI18nConfig(
|
||||
templates={
|
||||
EmailType.INVITE_MEMBER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Invite: {unknown_placeholder}",
|
||||
template_path="invite_member_en.html",
|
||||
branded_template_path="branded/invite_member_en.html",
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
branding_service = MockBrandingService(enabled=False)
|
||||
service = EmailI18nService(
|
||||
config=config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
|
||||
service.send_email(
|
||||
email_type=EmailType.INVITE_MEMBER,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
# Subject is left unformatted due to KeyError fallback path without application_title
|
||||
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
|
||||
|
||||
def test_send_change_email_old_phase(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email for old email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
|
||||
@ -290,7 +327,7 @@ class TestEmailI18nService:
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email for new email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
|
||||
@ -325,7 +362,7 @@ class TestEmailI18nService:
|
||||
def test_send_change_email_invalid_phase(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending change email with invalid phase raises error"""
|
||||
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
|
||||
email_service.send_change_email(
|
||||
@ -339,7 +376,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending raw email to single recipient"""
|
||||
email_service.send_raw_email(
|
||||
to="test@example.com",
|
||||
@ -357,7 +394,7 @@ class TestEmailI18nService:
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> None:
|
||||
):
|
||||
"""Test sending raw email to multiple recipients"""
|
||||
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
|
||||
|
||||
@ -378,7 +415,7 @@ class TestEmailI18nService:
|
||||
def test_get_template_missing_email_type(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
) -> None:
|
||||
):
|
||||
"""Test getting template for missing email type raises error"""
|
||||
with pytest.raises(ValueError, match="No templates configured for email type"):
|
||||
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
|
||||
@ -386,7 +423,7 @@ class TestEmailI18nService:
|
||||
def test_get_template_missing_language_and_english(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
) -> None:
|
||||
):
|
||||
"""Test error when neither requested language nor English fallback exists"""
|
||||
# Add template without English fallback
|
||||
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
|
||||
@ -407,7 +444,7 @@ class TestEmailI18nService:
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
) -> None:
|
||||
):
|
||||
"""Test subject templating with custom variables"""
|
||||
# Add template with variable in subject
|
||||
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
|
||||
@ -437,7 +474,7 @@ class TestEmailI18nService:
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "You are now the owner of My Workspace"
|
||||
|
||||
def test_email_language_from_language_code(self) -> None:
|
||||
def test_email_language_from_language_code(self):
|
||||
"""Test EmailLanguage.from_language_code method"""
|
||||
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
|
||||
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
|
||||
@ -448,7 +485,7 @@ class TestEmailI18nService:
|
||||
class TestEmailI18nIntegration:
|
||||
"""Integration tests for email i18n components"""
|
||||
|
||||
def test_create_default_email_config(self) -> None:
|
||||
def test_create_default_email_config(self):
|
||||
"""Test creating default email configuration"""
|
||||
config = create_default_email_config()
|
||||
|
||||
@ -476,7 +513,7 @@ class TestEmailI18nIntegration:
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
|
||||
|
||||
def test_get_email_i18n_service(self) -> None:
|
||||
def test_get_email_i18n_service(self):
|
||||
"""Test getting global email i18n service instance"""
|
||||
service1 = get_email_i18n_service()
|
||||
service2 = get_email_i18n_service()
|
||||
@ -484,7 +521,7 @@ class TestEmailI18nIntegration:
|
||||
# Should return the same instance
|
||||
assert service1 is service2
|
||||
|
||||
def test_flask_email_renderer(self) -> None:
|
||||
def test_flask_email_renderer(self):
|
||||
"""Test FlaskEmailRenderer implementation"""
|
||||
renderer = FlaskEmailRenderer()
|
||||
|
||||
@ -494,7 +531,7 @@ class TestEmailI18nIntegration:
|
||||
with pytest.raises(TemplateNotFound):
|
||||
renderer.render_template("test.html", foo="bar")
|
||||
|
||||
def test_flask_mail_sender_not_initialized(self) -> None:
|
||||
def test_flask_mail_sender_not_initialized(self):
|
||||
"""Test FlaskMailSender when mail is not initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
@ -514,7 +551,7 @@ class TestEmailI18nIntegration:
|
||||
# Restore original mail
|
||||
libs.email_i18n.mail = original_mail
|
||||
|
||||
def test_flask_mail_sender_initialized(self) -> None:
|
||||
def test_flask_mail_sender_initialized(self):
|
||||
"""Test FlaskMailSender when mail is initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
|
||||
122
api/tests/unit_tests/libs/test_external_api.py
Normal file
122
api/tests/unit_tests/libs/test_external_api.py
Normal file
@ -0,0 +1,122 @@
|
||||
from flask import Blueprint, Flask
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
||||
def _create_api_app():
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("t", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/bad-request")
|
||||
class Bad(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise BadRequest("invalid input")
|
||||
|
||||
@api.route("/unauth")
|
||||
class Unauth(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise Unauthorized("auth required")
|
||||
|
||||
@api.route("/value-error")
|
||||
class ValErr(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise ValueError("boom")
|
||||
|
||||
@api.route("/quota")
|
||||
class Quota(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise AppInvokeQuotaExceededError("quota exceeded")
|
||||
|
||||
@api.route("/general")
|
||||
class Gen(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise RuntimeError("oops")
|
||||
|
||||
# Note: We avoid altering default_mediatype to keep normal error paths
|
||||
|
||||
# Special 400 message rewrite
|
||||
@api.route("/json-empty")
|
||||
class JsonEmpty(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Force the specific message the handler rewrites
|
||||
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
||||
raise e
|
||||
|
||||
# 400 mapping payload path
|
||||
@api.route("/param-errors")
|
||||
class ParamErrors(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Coerce a mapping description to trigger param error shaping
|
||||
e.description = {"field": "is required"} # type: ignore[assignment]
|
||||
raise e
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
return app
|
||||
|
||||
|
||||
def test_external_api_error_handlers_basic_paths():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# 400
|
||||
res = client.get("/api/bad-request")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "bad_request"
|
||||
assert data["status"] == 400
|
||||
|
||||
# 401
|
||||
res = client.get("/api/unauth")
|
||||
assert res.status_code == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# 400 ValueError
|
||||
res = client.get("/api/value-error")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["code"] == "invalid_param"
|
||||
|
||||
# 500 general
|
||||
res = client.get("/api/general")
|
||||
assert res.status_code == 500
|
||||
assert res.get_json()["status"] == 500
|
||||
|
||||
|
||||
def test_external_api_json_message_and_bad_request_rewrite():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# JSON empty special rewrite
|
||||
res = client.get("/api/json-empty")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
55
api/tests/unit_tests/libs/test_file_utils.py
Normal file
55
api/tests/unit_tests/libs/test_file_utils.py
Normal file
@ -0,0 +1,55 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_parent(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "target.txt", max_search_parent_depth=5)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_current(tmp_path: Path):
|
||||
base = tmp_path / "x"
|
||||
base.mkdir()
|
||||
target = base / "here.txt"
|
||||
target.write_text("x", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "here.txt", max_search_parent_depth=1)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_not_found_raises(tmp_path: Path):
|
||||
base = tmp_path / "m" / "n"
|
||||
base.mkdir(parents=True)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
search_file_upwards(base, "missing.txt", max_search_parent_depth=3)
|
||||
# error message should contain file name and base path
|
||||
msg = str(exc.value)
|
||||
assert "missing.txt" in msg
|
||||
assert str(base) in msg
|
||||
|
||||
|
||||
def test_search_file_upwards_root_breaks_and_raises():
|
||||
# Using filesystem root triggers the 'break' branch (parent == current)
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1)
|
||||
|
||||
|
||||
def test_search_file_upwards_depth_limit_raises(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
# The file is 2 levels up from `c` (in `a`), but search depth is only 2.
|
||||
# The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3).
|
||||
# So, this should not find the file and should raise an error.
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(base, "target.txt", max_search_parent_depth=2)
|
||||
@ -1,6 +1,5 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask:
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str) -> Optional[User]:
|
||||
def load_user(user_id: str) -> User | None:
|
||||
if user_id == "test_user":
|
||||
return User("test_user")
|
||||
return None
|
||||
|
||||
88
api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
88
api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from libs.json_in_md_parser import (
|
||||
parse_and_check_json_markdown,
|
||||
parse_json_markdown,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_json():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1, "b": "x"}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"a": 1, "b": "x"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_generic():
|
||||
src = """
|
||||
```
|
||||
{"k": [1, 2, 3]}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"k": [1, 2, 3]}
|
||||
|
||||
|
||||
def test_parse_json_markdown_single_backticks():
|
||||
src = '`{"x": true}`'
|
||||
assert parse_json_markdown(src) == {"x": True}
|
||||
|
||||
|
||||
def test_parse_json_markdown_braces_only():
|
||||
src = ' {\n \t"ok": "yes"\n} '
|
||||
assert parse_json_markdown(src) == {"ok": "yes"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_not_found():
|
||||
with pytest.raises(ValueError):
|
||||
parse_json_markdown("no json here")
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_missing_key():
|
||||
src = """
|
||||
```
|
||||
{"present": 1}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, ["present", "missing"])
|
||||
assert "expected key `missing`" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_invalid_json():
|
||||
src = """
|
||||
```json
|
||||
{invalid json}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, [])
|
||||
assert "got invalid json object" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_success():
|
||||
src = """
|
||||
```json
|
||||
{"present": 1, "other": 2}
|
||||
```
|
||||
"""
|
||||
obj = parse_and_check_json_markdown(src, ["present"])
|
||||
assert obj == {"present": 1, "other": 2}
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_multiple_blocks_fails():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1}
|
||||
```
|
||||
Some text
|
||||
```json
|
||||
{"b": 2}
|
||||
```
|
||||
"""
|
||||
# The current implementation is greedy and will match from the first
|
||||
# opening fence to the last closing fence, causing JSON decode failure.
|
||||
with pytest.raises(OutputParserError):
|
||||
parse_and_check_json_markdown(src, [])
|
||||
19
api/tests/unit_tests/libs/test_oauth_base.py
Normal file
19
api/tests/unit_tests/libs/test_oauth_base.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_authorization_url()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_access_token("code")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_raw_user_info("token")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({}) # type: ignore[name-defined]
|
||||
@ -1,8 +1,8 @@
|
||||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
requests.exceptions.HTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
httpx.HTTPError,
|
||||
httpx.ConnectError,
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
||||
25
api/tests/unit_tests/libs/test_orjson.py
Normal file
25
api/tests/unit_tests/libs/test_orjson.py
Normal file
@ -0,0 +1,25 @@
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from libs.orjson import orjson_dumps
|
||||
|
||||
|
||||
def test_orjson_dumps_round_trip_basic():
|
||||
obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}}
|
||||
s = orjson_dumps(obj)
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_with_unicode_and_indent():
|
||||
obj = {"msg": "你好,Dify"}
|
||||
s = orjson_dumps(obj, option=orjson.OPT_INDENT_2)
|
||||
# contains indentation newline/spaces
|
||||
assert "\n" in s
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_non_utf8_encoding_fails():
|
||||
obj = {"msg": "你好"}
|
||||
# orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails.
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
orjson_dumps(obj, encoding="ascii")
|
||||
@ -4,7 +4,7 @@ from Crypto.PublicKey import RSA
|
||||
from libs import gmpy2_pkcs10aep_cipher
|
||||
|
||||
|
||||
def test_gmpy2_pkcs10aep_cipher() -> None:
|
||||
def test_gmpy2_pkcs10aep_cipher():
|
||||
rsa_key_pair = pyrsa.newkeys(2048)
|
||||
public_key = rsa_key_pair[0].save_pkcs1()
|
||||
private_key = rsa_key_pair[1].save_pkcs1()
|
||||
|
||||
53
api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
53
api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
@ -0,0 +1,53 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from python_http_client.exceptions import UnauthorizedError
|
||||
|
||||
from libs.sendgrid import SendGridClient
|
||||
|
||||
|
||||
def _mail(to: str = "user@example.com") -> dict:
|
||||
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_success(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
# nested attribute access: client.mail.send.post
|
||||
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
sg.send(_mail())
|
||||
|
||||
mock_client_cls.assert_called_once()
|
||||
mock_client.client.mail.send.post.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(ValueError):
|
||||
sg.send(_mail(to=""))
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(UnauthorizedError):
|
||||
sg.send(_mail())
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(TimeoutError):
|
||||
sg.send(_mail())
|
||||
100
api/tests/unit_tests/libs/test_smtp_client.py
Normal file
100
api/tests/unit_tests/libs/test_smtp_client.py
Normal file
@ -0,0 +1,100 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.smtp import SMTPClient
|
||||
|
||||
|
||||
def _mail() -> dict:
|
||||
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=587,
|
||||
username="user",
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=True,
|
||||
)
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
|
||||
assert mock_smtp.ehlo.call_count == 2
|
||||
mock_smtp.starttls.assert_called_once()
|
||||
mock_smtp.login.assert_called_once_with("user", "pass")
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL")
|
||||
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
# Cover SMTP_SSL branch and TimeoutError handling
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
|
||||
mock_smtp_ssl_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=465,
|
||||
username="",
|
||||
password="",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=False,
|
||||
)
|
||||
with pytest.raises(TimeoutError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = RuntimeError("oops")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
with pytest.raises(RuntimeError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
|
||||
# Ensure we hit the specific SMTPException except branch
|
||||
import smtplib
|
||||
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=25,
|
||||
username="user", # non-empty to trigger login
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
)
|
||||
with pytest.raises(smtplib.SMTPException):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
Reference in New Issue
Block a user