Compare commits
1 Commits
main
...
fix/coderu
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e42cf3b2f |
1
.gitignore
vendored
1
.gitignore
vendored
@ -60,3 +60,4 @@ values-dev.yaml
|
||||
|
||||
*.tsbuildinfo
|
||||
|
||||
.coda/
|
||||
|
||||
@ -51,7 +51,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
||||
@ -346,40 +345,43 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||
}
|
||||
|
||||
func initCodeRunner() coderunner.Runner {
|
||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||
case "sandbox":
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(v, ",")
|
||||
// 为了安全考虑,移除不安全的direct runner,强制使用sandbox
|
||||
getAndSplit := func(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 60.0
|
||||
}
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100
|
||||
}
|
||||
return sandbox.NewRunner(config)
|
||||
default:
|
||||
return direct.NewRunner()
|
||||
return strings.Split(v, ",")
|
||||
}
|
||||
|
||||
// 使用安全的默认配置
|
||||
config := &sandbox.Config{
|
||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问
|
||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取
|
||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入
|
||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问
|
||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序
|
||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空,禁止FFI调用
|
||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||
TimeoutSeconds: 0,
|
||||
MemoryLimitMB: 0,
|
||||
}
|
||||
|
||||
// 设置安全的超时时间,最大30秒
|
||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 {
|
||||
config.TimeoutSeconds = f
|
||||
} else {
|
||||
config.TimeoutSeconds = 30.0 // 默认30秒超时
|
||||
}
|
||||
|
||||
// 设置安全的内存限制,最大100MB
|
||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 {
|
||||
config.MemoryLimitMB = mem
|
||||
} else {
|
||||
config.MemoryLimitMB = 100 // 默认100MB内存限制
|
||||
}
|
||||
|
||||
return sandbox.NewRunner(config)
|
||||
}
|
||||
|
||||
func initOCR() ocr.OCR {
|
||||
@ -798,4 +800,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
}
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -47,6 +48,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
node := mdParser.Parse(text.NewReader(b))
|
||||
cs := config.ChunkingStrategy
|
||||
ps := config.ParsingStrategy
|
||||
@ -101,13 +103,118 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return text
|
||||
}
|
||||
|
||||
// validateImageURL checks whether an image URL is safe to download.
//
// It returns nil only when urlString parses, uses the http or https scheme,
// and its host is one of the allow-listed domains or a subdomain of one.
// Otherwise it returns an error describing why the URL was rejected.
var validateImageURL = func(urlString string) error {
	parsedURL, err := url.Parse(urlString)
	if err != nil {
		return err
	}

	// Only HTTP and HTTPS are accepted.
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
	}

	// Domain allow-list; add further trusted domains here as needed.
	// NOTE(review): "cdn.example.com" looks like a placeholder — confirm it
	// is meant to ship.
	allowedDomains := []string{
		"images.unsplash.com",
		"cdn.example.com",
		"github.com",
		"githubusercontent.com",
	}

	hostname := parsedURL.Hostname()
	// DNS names are case-insensitive and may carry a trailing root dot, while
	// url.Parse leaves the host as written. Normalize before matching so that
	// "GitHub.com" and "github.com." are recognised as the allow-listed name.
	normalized := strings.ToLower(strings.TrimSuffix(hostname, "."))
	for _, domain := range allowedDomains {
		// The "." prefix keeps look-alikes such as "evilgithub.com" out.
		if normalized == domain || strings.HasSuffix(normalized, "."+domain) {
			return nil
		}
	}

	return fmt.Errorf("domain not allowed: %s", hostname)
}
|
||||
|
||||
// isPrivateIPAddress reports whether ip lies in a range that must not be
// reached by outbound image downloads: private, loopback, link-local,
// unspecified, shared (CGNAT), multicast or reserved address space.
//
// A nil or malformed ip cannot be classified and is reported as private so
// that callers fail closed.
var isPrivateIPAddress = func(ip net.IP) bool {
	if ip.To16() == nil {
		return true
	}

	blockedRanges := []string{
		"0.0.0.0/8",      // "this network", including the unspecified address
		"10.0.0.0/8",     // RFC 1918 private
		"100.64.0.0/10",  // RFC 6598 shared address space (CGNAT)
		"127.0.0.0/8",    // IPv4 loopback
		"169.254.0.0/16", // IPv4 link-local
		"172.16.0.0/12",  // RFC 1918 private
		"192.168.0.0/16", // RFC 1918 private
		"224.0.0.0/4",    // IPv4 multicast
		"240.0.0.0/4",    // reserved, including limited broadcast
		"::/128",         // IPv6 unspecified
		"::1/128",        // IPv6 loopback
		"fc00::/7",       // IPv6 unique local
		"fe80::/10",      // IPv6 link-local
		"ff00::/8",       // IPv6 multicast
	}

	for _, r := range blockedRanges {
		_, network, err := net.ParseCIDR(r)
		if err != nil {
			// A malformed table entry must never turn into an allow.
			return true
		}
		// Contains converts IPv4-mapped IPv6 addresses to IPv4 itself, so
		// "::ffff:10.0.0.1" is matched by the IPv4 ranges above.
		if network.Contains(ip) {
			return true
		}
	}

	return false
}
|
||||
|
||||
// isPrivateIP 检查是否为私有IP地址
|
||||
isPrivateIP := func(host string) bool {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
// 可能是域名,需要解析
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return true // 解析失败,拒绝访问
|
||||
}
|
||||
|
||||
// 检查所有解析的IP
|
||||
for _, resolvedIP := range ips {
|
||||
if isPrivateIPAddress(resolvedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return isPrivateIPAddress(ip)
|
||||
}
|
||||
|
||||
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
// Validate the URL before issuing any request
|
||||
if err := validateImageURL(url); err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Use an HTTP client whose dialer refuses private addresses
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Refuse connections to private IP addresses
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isPrivateIP(host) {
|
||||
return nil, fmt.Errorf("access to private IP denied: %s", host)
|
||||
}
|
||||
|
||||
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Set an explicit User-Agent header on the request
|
||||
req.Header.Set("User-Agent", "CozeStudio/1.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download image: %w", err)
|
||||
@ -118,7 +225,11 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
||||
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
// Limit how much of the response body is read
|
||||
const maxImageSize = 10 * 1024 * 1024 // 10MB
|
||||
limitedReader := io.LimitReader(resp.Body, maxImageSize)
|
||||
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read image content: %w", err)
|
||||
}
|
||||
|
||||
@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
|
||||
var processedParams []interface{}
|
||||
var err error
|
||||
|
||||
// Handle SQLType: if raw, do not process params
|
||||
// Raw SQL execution is disabled to prevent SQL injection attacks
|
||||
if req.SQLType == entity2.SQLType_Raw {
|
||||
processedSQL = req.SQL
|
||||
processedParams = nil
|
||||
} else {
|
||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons")
|
||||
}
|
||||
|
||||
// Always go through parameterized query processing
|
||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||
}
|
||||
|
||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
||||
@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s
|
||||
return whereClause.String(), values, nil
|
||||
}
|
||||
return "", values, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user