fix: When agents use data tables, different users influence each other (#565)
This commit is contained in:
@ -24,6 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/pkg/parser/ast"
|
||||
"github.com/pingcap/tidb/pkg/parser/format"
|
||||
"github.com/pingcap/tidb/pkg/parser/mysql"
|
||||
"github.com/pingcap/tidb/pkg/parser/opcode"
|
||||
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
||||
@ -411,3 +412,81 @@ func (p *Impl) GetInsertDataNums(sql string) (int, error) {
|
||||
|
||||
return len(insert.Lists), nil
|
||||
}
|
||||
func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) {
|
||||
if sql == "" {
|
||||
return "", fmt.Errorf("empty SQL statement")
|
||||
}
|
||||
if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) {
|
||||
return "", fmt.Errorf("invalid filter operator: %s", op)
|
||||
}
|
||||
if filter == "" {
|
||||
return "", fmt.Errorf("empty filter condition")
|
||||
}
|
||||
stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
// extract WHERE clause
|
||||
var originalWhere ast.ExprNode
|
||||
switch stmt := stmtNode.(type) {
|
||||
case *ast.SelectStmt:
|
||||
originalWhere = stmt.Where
|
||||
case *ast.UpdateStmt:
|
||||
originalWhere = stmt.Where
|
||||
case *ast.DeleteStmt:
|
||||
originalWhere = stmt.Where
|
||||
default:
|
||||
return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE")
|
||||
}
|
||||
tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter)
|
||||
tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse filter condition failed: %v", err)
|
||||
}
|
||||
newExpr := tmpNode.(*ast.SelectStmt).Where
|
||||
mergedExpr := mergeExpr(originalWhere, newExpr, op)
|
||||
// update AST
|
||||
switch stmt := stmtNode.(type) {
|
||||
case *ast.SelectStmt:
|
||||
stmt.Where = mergedExpr
|
||||
case *ast.UpdateStmt:
|
||||
stmt.Where = mergedExpr
|
||||
case *ast.DeleteStmt:
|
||||
stmt.Where = mergedExpr
|
||||
}
|
||||
|
||||
// regenerate SQL
|
||||
var sb strings.Builder
|
||||
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes
|
||||
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
||||
if err := stmtNode.Restore(restoreCtx); err != nil {
|
||||
return "", fmt.Errorf("gen SQL failed: %v", err)
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode {
|
||||
if left == nil {
|
||||
return right
|
||||
}
|
||||
if right == nil {
|
||||
return left
|
||||
}
|
||||
|
||||
switch op {
|
||||
case sqlparser.SQLFilterOpAnd:
|
||||
return &ast.BinaryOperationExpr{
|
||||
Op: opcode.LogicAnd,
|
||||
L: left,
|
||||
R: right,
|
||||
}
|
||||
case sqlparser.SQLFilterOpOr:
|
||||
return &ast.BinaryOperationExpr{
|
||||
Op: opcode.LogicOr,
|
||||
L: left,
|
||||
R: right,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user