Created
August 14, 2019 13:50
-
-
Save Ehco1996/ab6caeac1a6bca1fa2138afebb9ff205 to your computer and use it in GitHub Desktop.
过滤掉除 SELECT 之外的所有语句，并为所有 SELECT 语句加上 `LIMIT 1`（已有 LIMIT 的则修改为 `LIMIT 1`）。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"log" | |
"strings" | |
"github.com/pingcap/parser" | |
"github.com/pingcap/parser/ast" | |
"github.com/pingcap/parser/format" | |
_ "github.com/pingcap/tidb/types/parser_driver" | |
) | |
// Rewrite sql Rewrite | |
type Rewrite struct { | |
SQL string | |
NewSQL string | |
Stmt ast.StmtNode | |
} | |
// NewRewrite Func | |
func NewRewrite(sql, charset, collation string) *Rewrite { | |
p := parser.New() | |
stmtNode, err := p.ParseOneStmt(sql, charset, collation) | |
if err != nil { | |
log.Fatal("error...", err) | |
} | |
return &Rewrite{ | |
SQL: sql, | |
Stmt: stmtNode, | |
} | |
} | |
func newLimit(val int) *ast.Limit { | |
limit := ast.Limit{ | |
Count: ast.NewValueExpr(val), | |
} | |
return &limit | |
} | |
type checkLimitVisitor struct{} | |
func (clv *checkLimitVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { | |
switch node := in.(type) { | |
case *ast.Limit: | |
count := ast.NewValueExpr(1) | |
node.Count = count | |
return node, false | |
case *ast.SelectStmt: | |
node.Limit = newLimit(1) | |
} | |
return in, true | |
} | |
func (clv *checkLimitVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { | |
return in, true | |
} | |
func (rw *Rewrite) forceSelectLimit1() *Rewrite { | |
if rw.Stmt == nil { | |
return rw | |
} | |
foundSelect := false | |
switch stmt := rw.Stmt.(type) { | |
case *ast.SelectStmt: | |
v := checkLimitVisitor{} | |
stmt.Accept(&v) | |
foundSelect = true | |
} | |
if foundSelect { | |
var sb strings.Builder | |
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) | |
rw.Stmt.Restore(ctx) | |
rw.NewSQL = sb.String() | |
} | |
return rw | |
} | |
func main() { | |
sql1 := "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100 limit 100;" | |
rw := NewRewrite(sql1, "", "") | |
rw.forceSelectLimit1() | |
fmt.Println(rw.NewSQL) | |
// OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1 | |
sql2 := "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100;" | |
rw = NewRewrite(sql2, "", "") | |
rw.forceSelectLimit1() | |
fmt.Println(rw.NewSQL) | |
// OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1 | |
sql3 := "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste';" | |
rw = NewRewrite(sql3, "", "") | |
rw.forceSelectLimit1() | |
fmt.Println(rw.NewSQL) | |
// OUT: "" | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
如果需要将所有查询参数替换为 `?`，需要怎么处理呢？