Files
ragflow/test/unit_test/common/test_misc_utils.py
dale053 6ab25bf715 fix: block SSRF in misc_utils.download_img for OAuth avatars (#14868)
### What problem does this PR solve?

Closes #14865

`download_img` in `common/misc_utils.py` is used for OAuth avatar URLs.
The previous implementation called `async_request` from
`common.http_client`, which followed redirects without re-validating
each hop and did not apply the same SSRF protections as this path needs.
That made it possible to reach non-public or disallowed targets (for
example via redirects or unsafe URLs) when fetching avatars.

This change replaces that flow with an explicit, bounded fetch: each URL
(including every redirect target) is checked with
`common.ssrf_guard.assert_url_is_safe`, DNS is pinned with
`pin_dns_global`, `httpx` streams the body with `follow_redirects=False`
and a manual redirect loop (capped by
`RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS`), and total response size is capped
(`RAGFLOW_OAUTH_AVATAR_MAX_BYTES`). Timeouts, proxy, and user agent
align with `HTTP_CLIENT_*` env vars without importing `http_client`, so
lightweight tests stay simple.

Unit tests cover empty/None URLs, loopback, cloud metadata-style
addresses, and disallowed schemes so SSRF regressions are caught early.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-05-22 12:12:04 +08:00

497 lines
16 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import hashlib
import sys
import types
import uuid
from contextlib import contextmanager
from unittest.mock import patch
from common import ssrf_guard
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int
class _Hdr:
def __init__(self, mapping: dict[str, str]):
self._m = {k.lower(): v for k, v in mapping.items()}
def get(self, key: str, default=None):
return self._m.get(key.lower(), default)
class _MockStreamResp:
def __init__(self, status_code: int, *, location: str | None = None, body: bytes = b""):
self.status_code = status_code
hdrs: dict[str, str] = {}
if location is not None:
hdrs["Location"] = location
if body:
hdrs.setdefault("Content-Type", "image/jpeg")
self.headers = _Hdr(hdrs)
self._body = body
async def aclose(self):
return None
async def aiter_bytes(self):
if self._body:
yield self._body
class _FakeStreamCtx:
def __init__(self, resp: _MockStreamResp):
self._resp = resp
async def __aenter__(self):
return self._resp
async def __aexit__(self, exc_type, exc, tb):
return None
class _FakeAsyncClient:
def __init__(self, responses: list[_MockStreamResp]):
self._responses = responses
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
def stream(self, method, url, headers=None):
if not self._responses:
return _FakeStreamCtx(_MockStreamResp(404))
return _FakeStreamCtx(self._responses.pop(0))
@contextmanager
def _fake_httpx_sys_modules(client):
"""Minimal ``httpx`` stub so ``download_img`` can be exercised without real httpx."""
saved = sys.modules.get("httpx")
fake = types.ModuleType("httpx")
class _Timeout:
def __init__(self, *_a, **_kw):
pass
fake.Timeout = _Timeout
def _AsyncClient(*_a, **_kw):
return client
fake.AsyncClient = _AsyncClient
sys.modules["httpx"] = fake
try:
yield
finally:
if saved is not None:
sys.modules["httpx"] = saved
else:
sys.modules.pop("httpx", None)
class TestGetUuid:
"""Test cases for get_uuid function"""
def test_returns_string(self):
"""Test that function returns a string"""
result = get_uuid()
assert isinstance(result, str)
def test_hex_format(self):
"""Test that returned string is in hex format"""
result = get_uuid()
# UUID v1 hex should be 32 characters (without dashes)
assert len(result) == 32
# Should only contain hexadecimal characters
assert all(c in '0123456789abcdef' for c in result)
def test_no_dashes_in_result(self):
"""Test that result contains no dashes"""
result = get_uuid()
assert '-' not in result
def test_unique_results(self):
"""Test that multiple calls return different UUIDs"""
results = [get_uuid() for _ in range(10)]
# All results should be unique
assert len(results) == len(set(results))
# All should be valid hex strings of correct length
for result in results:
assert len(result) == 32
assert all(c in '0123456789abcdef' for c in result)
def test_valid_uuid_structure(self):
"""Test that the hex string can be converted back to UUID"""
result = get_uuid()
# Should be able to create UUID from the hex string
reconstructed_uuid = uuid.UUID(hex=result)
assert isinstance(reconstructed_uuid, uuid.UUID)
# The hex representation should match the original
assert reconstructed_uuid.hex == result
def test_uuid1_specific_characteristics(self):
"""Test that UUID v1 characteristics are present"""
result = get_uuid()
uuid_obj = uuid.UUID(hex=result)
# UUID v1 should have version 1
assert uuid_obj.version == 1
# Variant should be RFC 4122
assert uuid_obj.variant == 'specified in RFC 4122'
def test_result_length_consistency(self):
"""Test that all generated UUIDs have consistent length"""
for _ in range(100):
result = get_uuid()
assert len(result) == 32
def test_hex_characters_only(self):
"""Test that only valid hex characters are used"""
for _ in range(100):
result = get_uuid()
# Should only contain lowercase hex characters (UUID hex is lowercase)
assert result.islower()
assert all(c in '0123456789abcdef' for c in result)
class TestDownloadImg:
"""Test cases for download_img function"""
def test_empty_url_returns_empty_string(self):
"""Test that empty URL returns empty string"""
result = asyncio.run(download_img(""))
assert result == ""
def test_none_url_returns_empty_string(self):
"""Test that None URL returns empty string"""
result = asyncio.run(download_img(None))
assert result == ""
def test_loopback_url_blocked(self):
"""OAuth avatar fetch must not call loopback (SSRF regression)."""
result = asyncio.run(download_img("http://127.0.0.1/avatar.png"))
assert result == ""
def test_metadata_ip_blocked(self):
"""Link-local / cloud metadata ranges are non-global and must be rejected."""
result = asyncio.run(download_img("http://169.254.169.254/latest/meta-data/"))
assert result == ""
def test_disallowed_scheme_blocked(self):
result = asyncio.run(download_img("file:///etc/passwd"))
assert result == ""
def test_redirect_to_loopback_blocked(self):
"""Redirect from an allowed host to loopback must be rejected (SSRF)."""
from urllib.parse import urlparse
real_assert = ssrf_guard.assert_url_is_safe
def selective_assert(url: str, **kwargs):
host = urlparse(url).hostname or ""
if host in ("127.0.0.1", "localhost", "169.254.169.254"):
return real_assert(url, **kwargs)
return ("public-avatar.test", "8.8.8.8")
client = _FakeAsyncClient([_MockStreamResp(302, location="http://127.0.0.1/next")])
with (
patch.object(ssrf_guard, "assert_url_is_safe", side_effect=selective_assert),
_fake_httpx_sys_modules(client),
):
result = asyncio.run(download_img("http://public-avatar.test/start.png"))
assert result == ""
def test_redirect_too_many_hops_blocked(self):
"""Excessive redirect chains must return empty without hanging."""
import common.misc_utils as misc_utils
hops = [
_MockStreamResp(302, location="http://h.example/1"),
_MockStreamResp(302, location="http://h.example/2"),
_MockStreamResp(302, location="http://h.example/3"),
]
client = _FakeAsyncClient(hops)
with (
patch.object(misc_utils, "_OAUTH_AVATAR_MAX_REDIRECTS", 2),
patch.object(ssrf_guard, "assert_url_is_safe", return_value=("h.example", "8.8.8.8")),
_fake_httpx_sys_modules(client),
):
result = asyncio.run(misc_utils.download_img("http://h.example/start"))
assert result == ""
class TestHashStr2Int:
"""Test cases for hash_str2int function"""
def test_basic_hashing(self):
"""Test basic string hashing functionality"""
result = hash_str2int("hello")
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_default_mod_value(self):
"""Test that default mod value is 10^8"""
result = hash_str2int("test")
assert 0 <= result < 10 ** 8
def test_custom_mod_value(self):
"""Test with custom mod value"""
result = hash_str2int("test", mod=1000)
assert isinstance(result, int)
assert 0 <= result < 1000
def test_same_input_same_output(self):
"""Test that same input produces same output"""
result1 = hash_str2int("consistent")
result2 = hash_str2int("consistent")
result3 = hash_str2int("consistent")
assert result1 == result2 == result3
def test_different_input_different_output(self):
"""Test that different inputs produce different outputs (usually)"""
result1 = hash_str2int("hello")
result2 = hash_str2int("world")
result3 = hash_str2int("hello world")
# While hash collisions are possible, they're very unlikely for these inputs
results = [result1, result2, result3]
assert len(set(results)) == len(results)
def test_empty_string(self):
"""Test hashing empty string"""
result = hash_str2int("")
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_unicode_string(self):
"""Test hashing unicode strings"""
test_strings = [
"中文",
"🚀火箭",
"café",
"🎉",
"Hello 世界"
]
for test_str in test_strings:
result = hash_str2int(test_str)
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_special_characters(self):
"""Test hashing strings with special characters"""
test_strings = [
"hello@world.com",
"test#123",
"line\nwith\nnewlines",
"tab\tcharacter",
"space in string"
]
for test_str in test_strings:
result = hash_str2int(test_str)
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_large_string(self):
"""Test hashing large string"""
large_string = "x" * 10000
result = hash_str2int(large_string)
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_mod_value_1(self):
"""Test with mod value 1 (should always return 0)"""
result = hash_str2int("any string", mod=1)
assert result == 0
def test_mod_value_2(self):
"""Test with mod value 2 (should return 0 or 1)"""
result = hash_str2int("test", mod=2)
assert result in [0, 1]
def test_very_large_mod(self):
"""Test with very large mod value"""
result = hash_str2int("test", mod=10 ** 12)
assert isinstance(result, int)
assert 0 <= result < 10 ** 12
def test_hash_algorithm_sha1(self):
"""Test that SHA1 algorithm is used"""
test_string = "hello"
expected_hash = hashlib.sha1(test_string.encode("utf-8")).hexdigest()
expected_int = int(expected_hash, 16) % (10 ** 8)
result = hash_str2int(test_string)
assert result == expected_int
def test_utf8_encoding(self):
"""Test that UTF-8 encoding is used"""
# This should work without encoding errors
result = hash_str2int("café 🎉")
assert isinstance(result, int)
def test_range_with_different_mods(self):
"""Test that result is always in correct range for different mod values"""
test_cases = [
("test1", 100),
("test2", 1000),
("test3", 10000),
("test4", 999999),
]
for test_str, mod_val in test_cases:
result = hash_str2int(test_str, mod=mod_val)
assert 0 <= result < mod_val
def test_hexdigest_conversion(self):
"""Test the hexdigest to integer conversion"""
test_string = "hello"
hash_obj = hashlib.sha1(test_string.encode("utf-8"))
hex_digest = hash_obj.hexdigest()
expected_int = int(hex_digest, 16) % (10 ** 8)
result = hash_str2int(test_string)
assert result == expected_int
def test_consistent_with_direct_calculation(self):
"""Test that function matches direct hashlib usage"""
test_strings = ["a", "b", "abc", "hello world", "12345"]
for test_str in test_strings:
direct_result = int(hashlib.sha1(test_str.encode("utf-8")).hexdigest(), 16) % (10 ** 8)
function_result = hash_str2int(test_str)
assert function_result == direct_result
def test_numeric_strings(self):
"""Test hashing numeric strings"""
test_strings = ["123", "0", "999999", "3.14159", "-42"]
for test_str in test_strings:
result = hash_str2int(test_str)
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
def test_whitespace_strings(self):
"""Test hashing strings with various whitespace"""
test_strings = [
" leading",
"trailing ",
" both ",
"\ttab",
"new\nline",
"\r\nwindows"
]
for test_str in test_strings:
result = hash_str2int(test_str)
assert isinstance(result, int)
assert 0 <= result < 10 ** 8
class TestConvertBytes:
"""Test suite for convert_bytes function"""
def test_zero_bytes(self):
"""Test that 0 bytes returns '0 B'"""
assert convert_bytes(0) == "0 B"
def test_single_byte(self):
"""Test single byte values"""
assert convert_bytes(1) == "1 B"
assert convert_bytes(999) == "999 B"
def test_kilobyte_range(self):
"""Test values in kilobyte range with different precisions"""
# Exactly 1 KB
assert convert_bytes(1024) == "1.00 KB"
# Values that should show 1 decimal place (10-99.9 range)
assert convert_bytes(15360) == "15.0 KB" # 15 KB exactly
assert convert_bytes(10752) == "10.5 KB" # 10.5 KB
# Values that should show 2 decimal places (1-9.99 range)
assert convert_bytes(2048) == "2.00 KB" # 2 KB exactly
assert convert_bytes(3072) == "3.00 KB" # 3 KB exactly
assert convert_bytes(5120) == "5.00 KB" # 5 KB exactly
def test_megabyte_range(self):
"""Test values in megabyte range"""
# Exactly 1 MB
assert convert_bytes(1048576) == "1.00 MB"
# Values with different precision requirements
assert convert_bytes(15728640) == "15.0 MB" # 15.0 MB
assert convert_bytes(11010048) == "10.5 MB" # 10.5 MB
def test_gigabyte_range(self):
"""Test values in gigabyte range"""
# Exactly 1 GB
assert convert_bytes(1073741824) == "1.00 GB"
# Large value that should show 0 decimal places
assert convert_bytes(3221225472) == "3.00 GB" # 3 GB exactly
def test_terabyte_range(self):
"""Test values in terabyte range"""
assert convert_bytes(1099511627776) == "1.00 TB" # 1 TB
def test_petabyte_range(self):
"""Test values in petabyte range"""
assert convert_bytes(1125899906842624) == "1.00 PB" # 1 PB
def test_boundary_values(self):
"""Test values at unit boundaries"""
# Just below 1 KB
assert convert_bytes(1023) == "1023 B"
# Just above 1 KB
assert convert_bytes(1025) == "1.00 KB"
# At 100 KB boundary (should switch to 0 decimal places)
assert convert_bytes(102400) == "100 KB"
assert convert_bytes(102300) == "99.9 KB"
def test_precision_transitions(self):
"""Test the precision formatting transitions"""
# Test transition from 2 decimal places to 1 decimal place
assert convert_bytes(9216) == "9.00 KB" # 9.00 KB (2 decimal places)
assert convert_bytes(10240) == "10.0 KB" # 10.0 KB (1 decimal place)
# Test transition from 1 decimal place to 0 decimal places
assert convert_bytes(102400) == "100 KB" # 100 KB (0 decimal places)
def test_large_values_no_overflow(self):
"""Test that very large values don't cause issues"""
# Very large value that should use PB
large_value = 10 * 1125899906842624 # 10 PB
assert "PB" in convert_bytes(large_value)
# Ensure we don't exceed available units
huge_value = 100 * 1125899906842624 # 100 PB (still within PB range)
assert "PB" in convert_bytes(huge_value)