Compare commits
1 Commits
feat/plugi
...
fix/coderu
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e42cf3b2f |
1
.gitignore
vendored
1
.gitignore
vendored
@ -60,3 +60,4 @@ values-dev.yaml
|
|||||||
|
|
||||||
*.tsbuildinfo
|
*.tsbuildinfo
|
||||||
|
|
||||||
|
.coda/
|
||||||
|
|||||||
@ -51,7 +51,6 @@ import (
|
|||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
||||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
||||||
@ -346,40 +345,43 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func initCodeRunner() coderunner.Runner {
|
func initCodeRunner() coderunner.Runner {
|
||||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
// 为了安全考虑,移除不安全的direct runner,强制使用sandbox
|
||||||
case "sandbox":
|
getAndSplit := func(key string) []string {
|
||||||
getAndSplit := func(key string) []string {
|
v := os.Getenv(key)
|
||||||
v := os.Getenv(key)
|
if v == "" {
|
||||||
if v == "" {
|
return nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return strings.Split(v, ",")
|
|
||||||
}
|
}
|
||||||
config := &sandbox.Config{
|
return strings.Split(v, ",")
|
||||||
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
|
|
||||||
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
|
|
||||||
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
|
|
||||||
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
|
|
||||||
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
|
|
||||||
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
|
|
||||||
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
|
||||||
TimeoutSeconds: 0,
|
|
||||||
MemoryLimitMB: 0,
|
|
||||||
}
|
|
||||||
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
|
|
||||||
config.TimeoutSeconds = f
|
|
||||||
} else {
|
|
||||||
config.TimeoutSeconds = 60.0
|
|
||||||
}
|
|
||||||
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
|
|
||||||
config.MemoryLimitMB = mem
|
|
||||||
} else {
|
|
||||||
config.MemoryLimitMB = 100
|
|
||||||
}
|
|
||||||
return sandbox.NewRunner(config)
|
|
||||||
default:
|
|
||||||
return direct.NewRunner()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 使用安全的默认配置
|
||||||
|
config := &sandbox.Config{
|
||||||
|
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问
|
||||||
|
AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取
|
||||||
|
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入
|
||||||
|
AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问
|
||||||
|
AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序
|
||||||
|
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空,禁止FFI调用
|
||||||
|
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
|
||||||
|
TimeoutSeconds: 0,
|
||||||
|
MemoryLimitMB: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置安全的超时时间,最大30秒
|
||||||
|
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 {
|
||||||
|
config.TimeoutSeconds = f
|
||||||
|
} else {
|
||||||
|
config.TimeoutSeconds = 30.0 // 默认30秒超时
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置安全的内存限制,最大100MB
|
||||||
|
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 {
|
||||||
|
config.MemoryLimitMB = mem
|
||||||
|
} else {
|
||||||
|
config.MemoryLimitMB = 100 // 默认100MB内存限制
|
||||||
|
}
|
||||||
|
|
||||||
|
return sandbox.NewRunner(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func initOCR() ocr.OCR {
|
func initOCR() ocr.OCR {
|
||||||
@ -798,4 +800,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return emb, nil
|
return emb, nil
|
||||||
}
|
}
|
||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@ -47,6 +48,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
node := mdParser.Parse(text.NewReader(b))
|
node := mdParser.Parse(text.NewReader(b))
|
||||||
cs := config.ChunkingStrategy
|
cs := config.ChunkingStrategy
|
||||||
ps := config.ParsingStrategy
|
ps := config.ParsingStrategy
|
||||||
@ -101,13 +103,118 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
|||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateImageURL 验证图片URL的安全性
|
||||||
|
validateImageURL := func(urlString string) error {
|
||||||
|
parsedURL, err := url.Parse(urlString)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只允许HTTP/HTTPS
|
||||||
|
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||||
|
return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查域名白名单
|
||||||
|
allowedDomains := []string{
|
||||||
|
"images.unsplash.com",
|
||||||
|
"cdn.example.com",
|
||||||
|
"github.com",
|
||||||
|
"githubusercontent.com",
|
||||||
|
// 可以根据需要添加其他受信任的域名
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := parsedURL.Hostname()
|
||||||
|
for _, domain := range allowedDomains {
|
||||||
|
if hostname == domain || strings.HasSuffix(hostname, "."+domain) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("domain not allowed: %s", hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrivateIPAddress 检查IP地址是否为私有地址
|
||||||
|
isPrivateIPAddress := func(ip net.IP) bool {
|
||||||
|
// 检查私有IP范围
|
||||||
|
privateRanges := []struct {
|
||||||
|
cidr string
|
||||||
|
}{
|
||||||
|
{"10.0.0.0/8"},
|
||||||
|
{"172.16.0.0/12"},
|
||||||
|
{"192.168.0.0/16"},
|
||||||
|
{"127.0.0.0/8"},
|
||||||
|
{"169.254.0.0/16"}, // 链路本地地址
|
||||||
|
{"::1/128"}, // IPv6 loopback
|
||||||
|
{"fc00::/7"}, // IPv6 私有地址
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range privateRanges {
|
||||||
|
_, cidr, _ := net.ParseCIDR(r.cidr)
|
||||||
|
if cidr.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrivateIP 检查是否为私有IP地址
|
||||||
|
isPrivateIP := func(host string) bool {
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip == nil {
|
||||||
|
// 可能是域名,需要解析
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return true // 解析失败,拒绝访问
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查所有解析的IP
|
||||||
|
for _, resolvedIP := range ips {
|
||||||
|
if isPrivateIPAddress(resolvedIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return isPrivateIPAddress(ip)
|
||||||
|
}
|
||||||
|
|
||||||
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
|
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
|
||||||
client := &http.Client{Timeout: 5 * time.Second}
|
// URL验证
|
||||||
|
if err := validateImageURL(url); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用安全的HTTP客户端
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// 禁止访问私有IP
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isPrivateIP(host) {
|
||||||
|
return nil, fmt.Errorf("access to private IP denied: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 添加安全头
|
||||||
|
req.Header.Set("User-Agent", "CozeStudio/1.0")
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to download image: %w", err)
|
return nil, fmt.Errorf("failed to download image: %w", err)
|
||||||
@ -118,7 +225,11 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
|
|||||||
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
|
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
// 限制响应大小
|
||||||
|
const maxImageSize = 10 * 1024 * 1024 // 10MB
|
||||||
|
limitedReader := io.LimitReader(resp.Body, maxImageSize)
|
||||||
|
|
||||||
|
data, err := io.ReadAll(limitedReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read image content: %w", err)
|
return nil, fmt.Errorf("failed to read image content: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
|
|||||||
var processedParams []interface{}
|
var processedParams []interface{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Handle SQLType: if raw, do not process params
|
// 禁用原始SQL执行以防止SQL注入攻击
|
||||||
if req.SQLType == entity2.SQLType_Raw {
|
if req.SQLType == entity2.SQLType_Raw {
|
||||||
processedSQL = req.SQL
|
return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons")
|
||||||
processedParams = nil
|
}
|
||||||
} else {
|
|
||||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
// 强制使用参数化查询
|
||||||
if err != nil {
|
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
if err != nil {
|
||||||
}
|
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
||||||
@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s
|
|||||||
return whereClause.String(), values, nil
|
return whereClause.String(), values, nil
|
||||||
}
|
}
|
||||||
return "", values, nil
|
return "", values, nil
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user