add internal ip filter when parse tool schema

This commit is contained in:
Yansong Zhang
2025-12-12 11:24:25 +08:00
parent acdbcdb6f8
commit 3f3b9beeff
5 changed files with 424 additions and 4 deletions

View File

@ -2,8 +2,10 @@
Proxy requests to avoid SSRF
"""
import ipaddress
import logging
import time
from urllib.parse import urlparse
import httpx
@ -12,6 +14,76 @@ from core.helper.http_client_pooling import get_pooled_http_client
logger = logging.getLogger(__name__)
def is_private_or_local_address(url: str) -> bool:
"""
Check if URL points to a private/local network address (SSRF protection).
This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks
by detecting private IP addresses, localhost, and local network domains.
Args:
url: The URL string to check
Returns:
True if the URL points to a private/local address, False otherwise
Examples:
>>> is_private_or_local_address("http://localhost/api")
True
>>> is_private_or_local_address("http://192.168.1.1/api")
True
>>> is_private_or_local_address("https://example.com/api")
False
"""
if not url:
return False
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
hostname_lower = hostname.lower()
# Check for localhost variants
if hostname_lower in ("localhost", "127.0.0.1", "::1"):
return True
# Check for .local domains (link-local)
if hostname_lower.endswith(".local"):
return True
# Try to parse as IP address
try:
ip = ipaddress.ip_address(hostname)
# Check if it's a private IP (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 for IPv4)
# For IPv6: fc00::/7 (unique local addresses)
if ip.is_private:
return True
# Check if it's loopback (127.0.0.0/8 for IPv4, ::1 for IPv6)
if ip.is_loopback:
return True
# Check if it's link-local (169.254.0.0/16 for IPv4, fe80::/10 for IPv6)
if ip.is_link_local:
return True
return False
except ValueError:
# Not a valid IP address, might be a domain name
# Domain names could resolve to private IPs, but we only check the literal hostname here
# For more thorough checks, DNS resolution would be needed (but adds latency)
return False
except (ValueError, TypeError, AttributeError):
return False
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5

View File

@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
class ToolSSRFError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass

View File

@ -8,10 +8,11 @@ import httpx
from flask import request
from yaml import YAMLError, safe_load
from core.helper.ssrf_proxy import is_private_or_local_address
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError
class ApiBasedToolSchemaParser:
@ -28,6 +29,15 @@ class ApiBasedToolSchemaParser:
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
# SSRF Protection: Validate all server URLs before processing
for server in openapi["servers"]:
server_url_to_check = server.get("url", "")
if server_url_to_check and is_private_or_local_address(server_url_to_check):
raise ToolSSRFError(
f"Server URL '{server_url_to_check}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
@ -287,6 +297,15 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
# SSRF Protection: Validate all server URLs before processing
for server in servers:
server_url_to_check = server.get("url", "")
if server_url_to_check and is_private_or_local_address(server_url_to_check):
raise ToolSSRFError(
f"Server URL '{server_url_to_check}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",
"info": {
@ -360,6 +379,13 @@ class ApiBasedToolSchemaParser:
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# SSRF Protection: Validate API URL before making HTTP request
if is_private_or_local_address(api_url):
raise ToolSSRFError(
f"API URL '{api_url}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
# get openapi yaml
response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
@ -424,8 +450,10 @@ class ApiBasedToolSchemaParser:
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
except ToolSSRFError:
raise
# openai parse error, fallback to swagger
# openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
@ -436,13 +464,18 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e