mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
add internal ip filter when parse tool schema
This commit is contained in:
@ -2,8 +2,10 @@
|
|||||||
Proxy requests to avoid SSRF
|
Proxy requests to avoid SSRF
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@ -12,6 +14,76 @@ from core.helper.http_client_pooling import get_pooled_http_client
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_private_or_local_address(url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if URL points to a private/local network address (SSRF protection).
|
||||||
|
|
||||||
|
This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks
|
||||||
|
by detecting private IP addresses, localhost, and local network domains.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL string to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the URL points to a private/local address, False otherwise
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> is_private_or_local_address("http://localhost/api")
|
||||||
|
True
|
||||||
|
>>> is_private_or_local_address("http://192.168.1.1/api")
|
||||||
|
True
|
||||||
|
>>> is_private_or_local_address("https://example.com/api")
|
||||||
|
False
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
hostname = parsed.hostname
|
||||||
|
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
|
||||||
|
hostname_lower = hostname.lower()
|
||||||
|
|
||||||
|
# Check for localhost variants
|
||||||
|
if hostname_lower in ("localhost", "127.0.0.1", "::1"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for .local domains (link-local)
|
||||||
|
if hostname_lower.endswith(".local"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Try to parse as IP address
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
|
||||||
|
# Check if it's a private IP (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 for IPv4)
|
||||||
|
# For IPv6: fc00::/7 (unique local addresses)
|
||||||
|
if ip.is_private:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if it's loopback (127.0.0.0/8 for IPv4, ::1 for IPv6)
|
||||||
|
if ip.is_loopback:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if it's link-local (169.254.0.0/16 for IPv4, fe80::/10 for IPv6)
|
||||||
|
if ip.is_link_local:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
# Not a valid IP address, might be a domain name
|
||||||
|
# Domain names could resolve to private IPs, but we only check the literal hostname here
|
||||||
|
# For more thorough checks, DNS resolution would be needed (but adds latency)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||||
|
|
||||||
BACKOFF_FACTOR = 0.5
|
BACKOFF_FACTOR = 0.5
|
||||||
|
|||||||
@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSSRFError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolCredentialPolicyViolationError(ValueError):
|
class ToolCredentialPolicyViolationError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -8,10 +8,11 @@ import httpx
|
|||||||
from flask import request
|
from flask import request
|
||||||
from yaml import YAMLError, safe_load
|
from yaml import YAMLError, safe_load
|
||||||
|
|
||||||
|
from core.helper.ssrf_proxy import is_private_or_local_address
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError
|
||||||
|
|
||||||
|
|
||||||
class ApiBasedToolSchemaParser:
|
class ApiBasedToolSchemaParser:
|
||||||
@ -28,6 +29,15 @@ class ApiBasedToolSchemaParser:
|
|||||||
if len(openapi["servers"]) == 0:
|
if len(openapi["servers"]) == 0:
|
||||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||||
|
|
||||||
|
# SSRF Protection: Validate all server URLs before processing
|
||||||
|
for server in openapi["servers"]:
|
||||||
|
server_url_to_check = server.get("url", "")
|
||||||
|
if server_url_to_check and is_private_or_local_address(server_url_to_check):
|
||||||
|
raise ToolSSRFError(
|
||||||
|
f"Server URL '{server_url_to_check}' points to a private or local network address, "
|
||||||
|
"which is not allowed for security reasons (SSRF protection)."
|
||||||
|
)
|
||||||
|
|
||||||
server_url = openapi["servers"][0]["url"]
|
server_url = openapi["servers"][0]["url"]
|
||||||
request_env = request.headers.get("X-Request-Env")
|
request_env = request.headers.get("X-Request-Env")
|
||||||
if request_env:
|
if request_env:
|
||||||
@ -287,6 +297,15 @@ class ApiBasedToolSchemaParser:
|
|||||||
if len(servers) == 0:
|
if len(servers) == 0:
|
||||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||||
|
|
||||||
|
# SSRF Protection: Validate all server URLs before processing
|
||||||
|
for server in servers:
|
||||||
|
server_url_to_check = server.get("url", "")
|
||||||
|
if server_url_to_check and is_private_or_local_address(server_url_to_check):
|
||||||
|
raise ToolSSRFError(
|
||||||
|
f"Server URL '{server_url_to_check}' points to a private or local network address, "
|
||||||
|
"which is not allowed for security reasons (SSRF protection)."
|
||||||
|
)
|
||||||
|
|
||||||
converted_openapi: dict[str, Any] = {
|
converted_openapi: dict[str, Any] = {
|
||||||
"openapi": "3.0.0",
|
"openapi": "3.0.0",
|
||||||
"info": {
|
"info": {
|
||||||
@ -360,6 +379,13 @@ class ApiBasedToolSchemaParser:
|
|||||||
if api_type != "openapi":
|
if api_type != "openapi":
|
||||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||||
|
|
||||||
|
# SSRF Protection: Validate API URL before making HTTP request
|
||||||
|
if is_private_or_local_address(api_url):
|
||||||
|
raise ToolSSRFError(
|
||||||
|
f"API URL '{api_url}' points to a private or local network address, "
|
||||||
|
"which is not allowed for security reasons (SSRF protection)."
|
||||||
|
)
|
||||||
|
|
||||||
# get openapi yaml
|
# get openapi yaml
|
||||||
response = httpx.get(
|
response = httpx.get(
|
||||||
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
|
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
|
||||||
@ -424,8 +450,10 @@ class ApiBasedToolSchemaParser:
|
|||||||
return openapi, schema_type
|
return openapi, schema_type
|
||||||
except ToolApiSchemaError as e:
|
except ToolApiSchemaError as e:
|
||||||
openapi_error = e
|
openapi_error = e
|
||||||
|
except ToolSSRFError:
|
||||||
|
raise
|
||||||
|
|
||||||
# openai parse error, fallback to swagger
|
# openapi parse error, fallback to swagger
|
||||||
try:
|
try:
|
||||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||||
loaded_content, extra_info=extra_info, warning=warning
|
loaded_content, extra_info=extra_info, warning=warning
|
||||||
@ -436,13 +464,18 @@ class ApiBasedToolSchemaParser:
|
|||||||
), schema_type
|
), schema_type
|
||||||
except ToolApiSchemaError as e:
|
except ToolApiSchemaError as e:
|
||||||
swagger_error = e
|
swagger_error = e
|
||||||
|
except ToolSSRFError:
|
||||||
|
# SSRF protection errors should be raised immediately, don't fallback
|
||||||
|
raise
|
||||||
# swagger parse error, fallback to openai plugin
|
# swagger parse error, fallback to openai plugin
|
||||||
try:
|
try:
|
||||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||||
)
|
)
|
||||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
|
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
|
||||||
|
except ToolSSRFError:
|
||||||
|
# SSRF protection errors should be raised immediately, don't fallback
|
||||||
|
raise
|
||||||
except ToolNotSupportedError as e:
|
except ToolNotSupportedError as e:
|
||||||
# maybe it's not plugin at all
|
# maybe it's not plugin at all
|
||||||
openapi_plugin_error = e
|
openapi_plugin_error = e
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.Client.request")
|
@patch("httpx.Client.request")
|
||||||
@ -50,3 +50,86 @@ def test_retry_logic_success(mock_request):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsPrivateOrLocalAddress:
|
||||||
|
"""Test cases for SSRF protection function."""
|
||||||
|
|
||||||
|
def test_localhost_variants(self):
|
||||||
|
"""Test that localhost variants are detected as private."""
|
||||||
|
assert is_private_or_local_address("http://localhost/api") is True
|
||||||
|
assert is_private_or_local_address("http://127.0.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("http://[::1]/api") is True
|
||||||
|
assert is_private_or_local_address("https://localhost:8080/") is True
|
||||||
|
|
||||||
|
def test_private_ipv4_ranges(self):
|
||||||
|
"""Test that private IPv4 ranges are detected."""
|
||||||
|
# 10.0.0.0/8
|
||||||
|
assert is_private_or_local_address("http://10.0.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("http://10.255.255.255/api") is True
|
||||||
|
|
||||||
|
# 172.16.0.0/12
|
||||||
|
assert is_private_or_local_address("http://172.16.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("http://172.31.255.255/api") is True
|
||||||
|
|
||||||
|
# 192.168.0.0/16
|
||||||
|
assert is_private_or_local_address("http://192.168.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("http://192.168.255.255/api") is True
|
||||||
|
|
||||||
|
# 169.254.0.0/16 (link-local)
|
||||||
|
assert is_private_or_local_address("http://169.254.1.1/api") is True
|
||||||
|
|
||||||
|
def test_local_domains(self):
|
||||||
|
"""Test that .local domains are detected as private."""
|
||||||
|
assert is_private_or_local_address("http://myserver.local/api") is True
|
||||||
|
assert is_private_or_local_address("https://test.local:8080/") is True
|
||||||
|
|
||||||
|
def test_public_addresses(self):
|
||||||
|
"""Test that public addresses are not detected as private."""
|
||||||
|
assert is_private_or_local_address("http://example.com/api") is False
|
||||||
|
assert is_private_or_local_address("https://api.openai.com/v1") is False
|
||||||
|
assert is_private_or_local_address("http://8.8.8.8/") is False
|
||||||
|
assert is_private_or_local_address("https://1.1.1.1/") is False
|
||||||
|
assert is_private_or_local_address("http://93.184.216.34/") is False
|
||||||
|
|
||||||
|
def test_edge_cases(self):
|
||||||
|
"""Test edge cases and invalid inputs."""
|
||||||
|
# Empty or None
|
||||||
|
assert is_private_or_local_address("") is False
|
||||||
|
assert is_private_or_local_address(None) is False
|
||||||
|
|
||||||
|
# Invalid URLs
|
||||||
|
assert is_private_or_local_address("not-a-url") is False
|
||||||
|
assert is_private_or_local_address("://invalid") is False
|
||||||
|
|
||||||
|
def test_ipv6_private_ranges(self):
|
||||||
|
"""Test that private IPv6 ranges are detected."""
|
||||||
|
# IPv6 loopback
|
||||||
|
assert is_private_or_local_address("http://[::1]/api") is True
|
||||||
|
|
||||||
|
# IPv6 link-local (fe80::/10)
|
||||||
|
assert is_private_or_local_address("http://[fe80::1]/api") is True
|
||||||
|
|
||||||
|
# IPv6 unique local (fc00::/7)
|
||||||
|
assert is_private_or_local_address("http://[fc00::1]/api") is True
|
||||||
|
assert is_private_or_local_address("http://[fd00::1]/api") is True
|
||||||
|
|
||||||
|
def test_public_ipv6(self):
|
||||||
|
"""Test that public IPv6 addresses are not detected as private."""
|
||||||
|
# Public IPv6 addresses (real examples)
|
||||||
|
# Google Public DNS IPv6
|
||||||
|
assert is_private_or_local_address("http://[2001:4860:4860::8888]/api") is False
|
||||||
|
# Cloudflare DNS IPv6
|
||||||
|
assert is_private_or_local_address("http://[2606:4700:4700::1111]/api") is False
|
||||||
|
|
||||||
|
def test_url_with_ports(self):
|
||||||
|
"""Test URLs with custom ports."""
|
||||||
|
assert is_private_or_local_address("http://localhost:8080/api") is True
|
||||||
|
assert is_private_or_local_address("http://192.168.1.1:3000/") is True
|
||||||
|
assert is_private_or_local_address("https://example.com:443/api") is False
|
||||||
|
|
||||||
|
def test_url_schemes(self):
|
||||||
|
"""Test different URL schemes."""
|
||||||
|
assert is_private_or_local_address("https://127.0.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("http://127.0.0.1/api") is True
|
||||||
|
assert is_private_or_local_address("https://example.com/api") is False
|
||||||
|
|||||||
228
api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py
Normal file
228
api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
"""Unit tests for SSRF protection in API schema parser."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from core.tools.errors import ToolSSRFError
|
||||||
|
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flask_app():
|
||||||
|
"""Create a Flask app for testing."""
|
||||||
|
app = Flask(__name__)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiSchemaParserSSRF:
|
||||||
|
"""Test SSRF protection in API schema parser."""
|
||||||
|
|
||||||
|
def test_openapi_with_private_ip_blocked(self, flask_app):
|
||||||
|
"""Test that OpenAPI schema with private IP is blocked."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://192.168.1.1/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
|
||||||
|
assert "192.168.1.1" in str(exc_info.value)
|
||||||
|
assert "private or local network address" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_with_localhost_blocked(self, flask_app):
|
||||||
|
"""Test that OpenAPI schema with localhost is blocked."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://localhost:8080/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
|
||||||
|
assert "localhost" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_with_local_domain_blocked(self, flask_app):
|
||||||
|
"""Test that OpenAPI schema with .local domain is blocked."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://myserver.local/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
|
||||||
|
assert "myserver.local" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_with_10_network_blocked(self, flask_app):
|
||||||
|
"""Test that OpenAPI schema with 10.x.x.x network is blocked."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://10.0.0.5/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
|
||||||
|
assert "10.0.0.5" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_with_public_url_allowed(self, flask_app):
|
||||||
|
"""Test that OpenAPI schema with public URL is allowed."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: https://api.example.com/v1
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
# Should not raise any exception
|
||||||
|
result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
def test_swagger_with_private_ip_blocked(self, flask_app):
|
||||||
|
"""Test that Swagger schema with private IP is blocked."""
|
||||||
|
swagger_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: http://172.16.0.1/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema)
|
||||||
|
|
||||||
|
assert "172.16.0.1" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_with_multiple_servers_one_private(self, flask_app):
|
||||||
|
"""Test that OpenAPI with multiple servers including one private is blocked."""
|
||||||
|
openapi_schema = """
|
||||||
|
openapi: 3.0.0
|
||||||
|
info:
|
||||||
|
title: Test API
|
||||||
|
version: 1.0.0
|
||||||
|
servers:
|
||||||
|
- url: https://api.example.com/v1
|
||||||
|
- url: http://192.168.1.100/api
|
||||||
|
paths:
|
||||||
|
/test:
|
||||||
|
get:
|
||||||
|
summary: Test endpoint
|
||||||
|
operationId: testGet
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Success
|
||||||
|
"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
|
||||||
|
|
||||||
|
assert "192.168.1.100" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_openapi_json_format_with_private_ip_blocked(self, flask_app):
|
||||||
|
"""Test that JSON format OpenAPI schema with private IP is blocked."""
|
||||||
|
openapi_json = """{
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {
|
||||||
|
"title": "Test API",
|
||||||
|
"version": "1.0.0"
|
||||||
|
},
|
||||||
|
"servers": [
|
||||||
|
{
|
||||||
|
"url": "http://127.0.0.1:8080/api"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"paths": {
|
||||||
|
"/test": {
|
||||||
|
"get": {
|
||||||
|
"summary": "Test endpoint",
|
||||||
|
"operationId": "testGet",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"""
|
||||||
|
with flask_app.test_request_context():
|
||||||
|
with pytest.raises(ToolSSRFError) as exc_info:
|
||||||
|
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json)
|
||||||
|
|
||||||
|
assert "127.0.0.1" in str(exc_info.value)
|
||||||
|
assert "SSRF protection" in str(exc_info.value)
|
||||||
Reference in New Issue
Block a user