mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
use squid for ssrf
This commit is contained in:
@ -2,79 +2,17 @@
|
||||
Proxy requests to avoid SSRF
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_private_or_local_address(url: str) -> bool:
    """
    Check if URL points to a private/local network address (SSRF protection).

    Only the literal host in the URL is inspected; DNS names are NOT resolved,
    so a public-looking domain that resolves to an internal address is not
    detected here. The following hosts are reported as private/local:

    - ``localhost`` and any ``*.localhost`` name (RFC 6761 loopback names)
    - ``*.local`` names (the mDNS namespace, RFC 6762)
    - IPv4/IPv6 literals that are private, loopback, or link-local
    - legacy ``inet_aton`` IPv4 spellings of such addresses (for example the
      decimal ``2130706433`` or the short form ``127.1``), which system
      resolvers accept as IP literals

    A trailing root dot (``localhost.``, ``127.0.0.1.``) is ignored before
    matching, because it names the same host.

    Args:
        url: The URL string to check

    Returns:
        True if the URL points to a private/local address, False otherwise
        (including when the URL is empty, malformed, or has no host part)

    Examples:
        >>> is_private_or_local_address("http://localhost/api")
        True
        >>> is_private_or_local_address("http://192.168.1.1/api")
        True
        >>> is_private_or_local_address("http://localhost./api")
        True
        >>> is_private_or_local_address("https://example.com/api")
        False
    """
    if not url:
        return False

    try:
        hostname = urlparse(url).hostname
    except (ValueError, TypeError, AttributeError):
        # Malformed URL (e.g. unbalanced IPv6 brackets) or a non-string argument.
        return False

    # urlparse yields a bytes hostname for bytes input; treat that as "not checkable".
    if not hostname or not isinstance(hostname, str):
        return False

    # Drop the optional root dot of a fully-qualified name so that
    # "localhost." and "127.0.0.1." cannot slip past the checks below.
    host = hostname.lower().rstrip(".")
    if not host:
        return False

    # Loopback names and mDNS names.
    if host == "localhost" or host.endswith((".localhost", ".local")):
        return True

    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        # Not a canonical IP literal. Before treating it as a domain name, try
        # the legacy inet_aton spellings (decimal, hex, octal, short forms)
        # that resolvers still interpret as IPv4 addresses.
        import socket

        try:
            ip = ipaddress.IPv4Address(socket.inet_aton(host))
        except (OSError, ValueError):
            # A genuine domain name; resolving it would be required for a
            # deeper check, which is deliberately not done here (latency).
            return False

    # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7
    # - Loopback: 127.0.0.0/8, ::1
    # - Link-local: 169.254.0.0/16, fe80::/10
    return ip.is_private or ip.is_loopback or ip.is_link_local
|
||||
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
@ -156,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
# Check for SSRF protection by Squid proxy
|
||||
if response.status_code in (401, 403):
|
||||
# Check if this is a Squid SSRF rejection
|
||||
server_header = response.headers.get("server", "").lower()
|
||||
via_header = response.headers.get("via", "").lower()
|
||||
|
||||
# Squid typically identifies itself in Server or Via headers
|
||||
if "squid" in server_header or "squid" in via_header:
|
||||
raise ToolSSRFError(
|
||||
f"Access to '{url}' was blocked by SSRF protection. "
|
||||
f"The URL may point to a private or local network address. "
|
||||
)
|
||||
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
return response
|
||||
|
||||
@ -8,33 +8,13 @@ import httpx
|
||||
from flask import request
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.helper.ssrf_proxy import is_private_or_local_address
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
def _validate_server_urls(servers: list[dict]) -> None:
    """
    Reject any server entry whose URL targets a private or local address.

    Entries without a 'url' key, or with an empty URL, are skipped.

    Args:
        servers: List of server dictionaries containing 'url' keys

    Raises:
        ToolSSRFError: If any server URL points to a private or local network address
    """
    for entry in servers:
        candidate = entry.get("url", "")
        if not candidate:
            # Nothing to validate for this entry.
            continue
        if is_private_or_local_address(candidate):
            message = (
                f"Server URL '{candidate}' points to a private or local network address, "
                "which is not allowed for security reasons (SSRF protection)."
            )
            raise ToolSSRFError(message)
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
@ -48,9 +28,6 @@ class ApiBasedToolSchemaParser:
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
# SSRF Protection: Validate all server URLs before processing
|
||||
ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"])
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
@ -310,9 +287,6 @@ class ApiBasedToolSchemaParser:
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
# SSRF Protection: Validate all server URLs before processing
|
||||
ApiBasedToolSchemaParser._validate_server_urls(servers)
|
||||
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
@ -386,13 +360,6 @@ class ApiBasedToolSchemaParser:
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# SSRF Protection: Validate API URL before making HTTP request
|
||||
if is_private_or_local_address(api_url):
|
||||
raise ToolSSRFError(
|
||||
f"API URL '{api_url}' points to a private or local network address, "
|
||||
"which is not allowed for security reasons (SSRF protection)."
|
||||
)
|
||||
|
||||
# get openapi yaml
|
||||
response = httpx.get(
|
||||
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
|
||||
@ -457,8 +424,6 @@ class ApiBasedToolSchemaParser:
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
except ToolSSRFError:
|
||||
raise
|
||||
|
||||
# openapi parse error, fallback to swagger
|
||||
try:
|
||||
@ -471,18 +436,12 @@ class ApiBasedToolSchemaParser:
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
except ToolSSRFError:
|
||||
# SSRF protection errors should be raised immediately, don't fallback
|
||||
raise
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
|
||||
except ToolSSRFError:
|
||||
# SSRF protection errors should be raised immediately, don't fallback
|
||||
raise
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
Reference in New Issue
Block a user