Files
coze-studio/backend/infra/impl/document/parser/builtin/parse_markdown.go
duanlijun 7e42cf3b2f fix: [Coda] 修复coderunner RCE/SSRF/SQL注入安全漏洞
(LogID: 2025091818571901007120218221212EA)

Co-Authored-By: Coda <coda@bytedance.com>
2025-09-18 19:29:50 +08:00

333 lines
8.4 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/text"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR) ParseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
options := parser.GetCommonOptions(&parser.Options{}, opts...)
mdParser := goldmark.DefaultParser()
b, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
node := mdParser.Parse(text.NewReader(b))
cs := config.ChunkingStrategy
ps := config.ParsingStrategy
if cs.ChunkType != contract.ChunkTypeCustom && cs.ChunkType != contract.ChunkTypeDefault {
return nil, fmt.Errorf("[ParseMarkdown] chunk type not support, chunk type=%d", cs.ChunkType)
}
var (
last *schema.Document
emptySlice bool
)
addSliceContent := func(content string) {
emptySlice = false
last.Content += content
}
newSlice := func(needOverlap bool) {
last = &schema.Document{
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
last.MetaData[k] = v
}
if needOverlap && cs.Overlap > 0 && len(docs) > 0 {
overlap := getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
addSliceContent(string(overlap))
}
emptySlice = true
}
pushSlice := func() {
if !emptySlice && last.Content != "" {
docs = append(docs, last)
newSlice(true)
}
}
trim := func(text string) string {
if cs.TrimURLAndEmail {
text = urlRegex.ReplaceAllString(text, "")
text = emailRegex.ReplaceAllString(text, "")
}
if cs.TrimSpace {
text = strings.TrimSpace(text)
text = spaceRegex.ReplaceAllString(text, " ")
}
return text
}
// validateImageURL 验证图片URL的安全性
validateImageURL := func(urlString string) error {
parsedURL, err := url.Parse(urlString)
if err != nil {
return err
}
// 只允许HTTP/HTTPS
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
}
// 检查域名白名单
allowedDomains := []string{
"images.unsplash.com",
"cdn.example.com",
"github.com",
"githubusercontent.com",
// 可以根据需要添加其他受信任的域名
}
hostname := parsedURL.Hostname()
for _, domain := range allowedDomains {
if hostname == domain || strings.HasSuffix(hostname, "."+domain) {
return nil
}
}
return fmt.Errorf("domain not allowed: %s", hostname)
}
// isPrivateIPAddress 检查IP地址是否为私有地址
isPrivateIPAddress := func(ip net.IP) bool {
// 检查私有IP范围
privateRanges := []struct {
cidr string
}{
{"10.0.0.0/8"},
{"172.16.0.0/12"},
{"192.168.0.0/16"},
{"127.0.0.0/8"},
{"169.254.0.0/16"}, // 链路本地地址
{"::1/128"}, // IPv6 loopback
{"fc00::/7"}, // IPv6 私有地址
}
for _, r := range privateRanges {
_, cidr, _ := net.ParseCIDR(r.cidr)
if cidr.Contains(ip) {
return true
}
}
return false
}
// isPrivateIP 检查是否为私有IP地址
isPrivateIP := func(host string) bool {
ip := net.ParseIP(host)
if ip == nil {
// 可能是域名,需要解析
ips, err := net.LookupIP(host)
if err != nil {
return true // 解析失败,拒绝访问
}
// 检查所有解析的IP
for _, resolvedIP := range ips {
if isPrivateIPAddress(resolvedIP) {
return true
}
}
return false
}
return isPrivateIPAddress(ip)
}
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
// URL验证
if err := validateImageURL(url); err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
// 使用安全的HTTP客户端
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 禁止访问私有IP
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if isPrivateIP(host) {
return nil, fmt.Errorf("access to private IP denied: %s", host)
}
return (&net.Dialer{}).DialContext(ctx, network, addr)
},
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// 添加安全头
req.Header.Set("User-Agent", "CozeStudio/1.0")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download image: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
}
// 限制响应大小
const maxImageSize = 10 * 1024 * 1024 // 10MB
limitedReader := io.LimitReader(resp.Body, maxImageSize)
data, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("failed to read image content: %w", err)
}
return data, nil
}
walker := func(n ast.Node, entering bool) (ast.WalkStatus, error) {
if !entering {
return ast.WalkContinue, nil
}
switch n.Kind() {
case ast.KindText:
if n.HasChildren() {
break
}
textNode := n.(*ast.Text)
plainText := trim(string(textNode.Segment.Value(b)))
for _, part := range strings.Split(plainText, cs.Separator) {
runes := []rune(part)
for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
pos := min(partLength, cs.ChunkSize-charCount(last.Content))
chunk := runes[:pos]
addSliceContent(string(chunk))
runes = runes[pos:]
if charCount(last.Content) >= cs.ChunkSize {
pushSlice()
}
}
}
case ast.KindImage:
if !ps.ExtractImage {
break
}
imageNode := n.(*ast.Image)
if ps.ExtractImage {
imageURL := string(imageNode.Destination)
if _, err = url.ParseRequestURI(imageURL); err == nil {
sp := strings.Split(imageURL, ".")
if len(sp) == 0 {
return ast.WalkStop, fmt.Errorf("failed to extract image extension, url=%s", imageURL)
}
ext := sp[len(sp)-1]
img, err := downloadImage(ctx, imageURL)
if err != nil {
return ast.WalkStop, fmt.Errorf("failed to download image: %w", err)
}
imgSrc, err := PutImageObject(ctx, storage, ext, GetCreatorIDFromExtraMeta(options.ExtraMeta), img)
if err != nil {
return ast.WalkStop, err
}
if !emptySlice && last.Content != "" {
pushSlice()
} else {
newSlice(false)
}
addSliceContent(fmt.Sprintf("\n%s\n", imgSrc))
if ps.ImageOCR && ocr != nil {
texts, err := ocr.FromBase64(ctx, base64.StdEncoding.EncodeToString(img))
if err != nil {
return ast.WalkStop, fmt.Errorf("failed to perform OCR on image: %w", err)
}
addSliceContent(strings.Join(texts, "\n"))
}
if charCount(last.Content) >= cs.ChunkSize {
pushSlice()
}
} else {
logs.CtxInfof(ctx, "[ParseMarkdown] not a valid image url, skip, got=%s", imageURL)
}
}
}
return ast.WalkContinue, nil
}
newSlice(false)
if err = ast.Walk(node, walker); err != nil {
return nil, err
}
if !emptySlice {
pushSlice()
}
return docs, nil
}
}