Compare commits

...

1 Commits

Author SHA1 Message Date
7e42cf3b2f fix: [Coda] 修复coderunner RCE/SSRF/SQL注入安全漏洞
(LogID: 2025091818571901007120218221212EA)

Co-Authored-By: Coda <coda@bytedance.com>
2025-09-18 19:29:50 +08:00
4 changed files with 160 additions and 46 deletions

1
.gitignore vendored
View File

@ -60,3 +60,4 @@ values-dev.yaml
*.tsbuildinfo
.coda/

View File

@ -51,7 +51,6 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
@ -346,40 +345,43 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
}
func initCodeRunner() coderunner.Runner {
switch typ := os.Getenv(consts.CodeRunnerType); typ {
case "sandbox":
getAndSplit := func(key string) []string {
v := os.Getenv(key)
if v == "" {
return nil
}
return strings.Split(v, ",")
// 为了安全考虑移除不安全的direct runner强制使用sandbox
getAndSplit := func(key string) []string {
v := os.Getenv(key)
if v == "" {
return nil
}
config := &sandbox.Config{
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
TimeoutSeconds: 0,
MemoryLimitMB: 0,
}
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
config.TimeoutSeconds = f
} else {
config.TimeoutSeconds = 60.0
}
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
config.MemoryLimitMB = mem
} else {
config.MemoryLimitMB = 100
}
return sandbox.NewRunner(config)
default:
return direct.NewRunner()
return strings.Split(v, ",")
}
// 使用安全的默认配置
config := &sandbox.Config{
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问
AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入
AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问
AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空禁止FFI调用
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
TimeoutSeconds: 0,
MemoryLimitMB: 0,
}
// 设置安全的超时时间最大30秒
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 {
config.TimeoutSeconds = f
} else {
config.TimeoutSeconds = 30.0 // 默认30秒超时
}
// 设置安全的内存限制最大100MB
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 {
config.MemoryLimitMB = mem
} else {
config.MemoryLimitMB = 100 // 默认100MB内存限制
}
return sandbox.NewRunner(config)
}
func initOCR() ocr.OCR {
@ -798,4 +800,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
}
return emb, nil
}
}

View File

@ -21,6 +21,7 @@ import (
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@ -47,6 +48,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return nil, err
}
node := mdParser.Parse(text.NewReader(b))
cs := config.ChunkingStrategy
ps := config.ParsingStrategy
@ -101,13 +103,118 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return text
}
// validateImageURL 验证图片URL的安全性
validateImageURL := func(urlString string) error {
parsedURL, err := url.Parse(urlString)
if err != nil {
return err
}
// 只允许HTTP/HTTPS
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
}
// 检查域名白名单
allowedDomains := []string{
"images.unsplash.com",
"cdn.example.com",
"github.com",
"githubusercontent.com",
// 可以根据需要添加其他受信任的域名
}
hostname := parsedURL.Hostname()
for _, domain := range allowedDomains {
if hostname == domain || strings.HasSuffix(hostname, "."+domain) {
return nil
}
}
return fmt.Errorf("domain not allowed: %s", hostname)
}
// isPrivateIPAddress 检查IP地址是否为私有地址
isPrivateIPAddress := func(ip net.IP) bool {
// 检查私有IP范围
privateRanges := []struct {
cidr string
}{
{"10.0.0.0/8"},
{"172.16.0.0/12"},
{"192.168.0.0/16"},
{"127.0.0.0/8"},
{"169.254.0.0/16"}, // 链路本地地址
{"::1/128"}, // IPv6 loopback
{"fc00::/7"}, // IPv6 私有地址
}
for _, r := range privateRanges {
_, cidr, _ := net.ParseCIDR(r.cidr)
if cidr.Contains(ip) {
return true
}
}
return false
}
// isPrivateIP 检查是否为私有IP地址
isPrivateIP := func(host string) bool {
ip := net.ParseIP(host)
if ip == nil {
// 可能是域名,需要解析
ips, err := net.LookupIP(host)
if err != nil {
return true // 解析失败,拒绝访问
}
// 检查所有解析的IP
for _, resolvedIP := range ips {
if isPrivateIPAddress(resolvedIP) {
return true
}
}
return false
}
return isPrivateIPAddress(ip)
}
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
client := &http.Client{Timeout: 5 * time.Second}
// URL验证
if err := validateImageURL(url); err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
// 使用安全的HTTP客户端
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 禁止访问私有IP
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if isPrivateIP(host) {
return nil, fmt.Errorf("access to private IP denied: %s", host)
}
return (&net.Dialer{}).DialContext(ctx, network, addr)
},
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// 添加安全头
req.Header.Set("User-Agent", "CozeStudio/1.0")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download image: %w", err)
@ -118,7 +225,11 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
// 限制响应大小
const maxImageSize = 10 * 1024 * 1024 // 10MB
limitedReader := io.LimitReader(resp.Body, maxImageSize)
data, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("failed to read image content: %w", err)
}

View File

@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
var processedParams []interface{}
var err error
// Handle SQLType: if raw, do not process params
// 禁用原始SQL执行以防止SQL注入攻击
if req.SQLType == entity2.SQLType_Raw {
processedSQL = req.SQL
processedParams = nil
} else {
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}
return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons")
}
// 强制使用参数化查询
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s
return whereClause.String(), values, nil
}
return "", values, nil
}
}