fix: When agents use data tables, different users influence each other (#565)

This commit is contained in:
liuyunchao-1998
2025-08-06 16:03:21 +08:00
committed by GitHub
parent 8b91a640b9
commit 60285ca014
4 changed files with 302 additions and 1 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
@ -411,3 +412,81 @@ func (p *Impl) GetInsertDataNums(sql string) (int, error) {
return len(insert.Lists), nil
}
func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) {
if sql == "" {
return "", fmt.Errorf("empty SQL statement")
}
if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) {
return "", fmt.Errorf("invalid filter operator: %s", op)
}
if filter == "" {
return "", fmt.Errorf("empty filter condition")
}
stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("failed to parse SQL: %v", err)
}
// extract WHERE clause
var originalWhere ast.ExprNode
switch stmt := stmtNode.(type) {
case *ast.SelectStmt:
originalWhere = stmt.Where
case *ast.UpdateStmt:
originalWhere = stmt.Where
case *ast.DeleteStmt:
originalWhere = stmt.Where
default:
return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE")
}
tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter)
tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("parse filter condition failed: %v", err)
}
newExpr := tmpNode.(*ast.SelectStmt).Where
mergedExpr := mergeExpr(originalWhere, newExpr, op)
// update AST
switch stmt := stmtNode.(type) {
case *ast.SelectStmt:
stmt.Where = mergedExpr
case *ast.UpdateStmt:
stmt.Where = mergedExpr
case *ast.DeleteStmt:
stmt.Where = mergedExpr
}
// regenerate SQL
var sb strings.Builder
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes
restoreCtx := format.NewRestoreCtx(flags, &sb)
if err := stmtNode.Restore(restoreCtx); err != nil {
return "", fmt.Errorf("gen SQL failed: %v", err)
}
return sb.String(), nil
}
func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode {
if left == nil {
return right
}
if right == nil {
return left
}
switch op {
case sqlparser.SQLFilterOpAnd:
return &ast.BinaryOperationExpr{
Op: opcode.LogicAnd,
L: left,
R: right,
}
case sqlparser.SQLFilterOpOr:
return &ast.BinaryOperationExpr{
Op: opcode.LogicOr,
L: left,
R: right,
}
default:
return nil
}
}