// Hooking into Go SQL drivers for triggering race conditions.
//
// Published as GitHub gist icio/0115df3444c68881ceb1251f806ac0e5,
// created January 14, 2021 15:22.
package repo_test | |
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"strings"
	"sync"
	"testing"

	"./repo"
)
func TestPersonRepo_Update_Conflict(t *testing.T) { | |
var realDB = connectToPostgres() | |
defer realDB.Close() | |
// These are the updates that we want to apply. Each update is going to | |
// happen in parallel and all will complete their reads before any start | |
// their writes. | |
pid := 1 | |
updates := []repo.Person{ | |
{ID: pid, Name: "Paul"}, | |
{ID: pid, Age: 32, | |
// Adding more than the number of retries will result in some workers | |
// may start getting too many serialisation errors from postgres. | |
} | |
// read WaitGroup is used to block all workers until they've all read, | |
// when the write channel will be closed and all reads/writes are unblocked. | |
write := make(chan bool) | |
var read sync.WaitGroup | |
read.Add(len(updates)) | |
go func() { | |
read.Wait() | |
close(write) | |
}() | |
// hookDB lets us synchronise with read and write after reading the account. | |
db := hookDB(realDB, sqlHooks{ | |
QueryPost: func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error) { | |
if strings.Contains(query, "FROM people") { | |
select { | |
case <-write: | |
// The synchronisation has already completed. Carry on. | |
default: | |
// Synchronise post-read/pre-write. | |
read.Done() | |
<-write | |
} | |
} | |
return rows, err | |
}, | |
}) | |
// Write the starting person. | |
r := repo.PersonRepo{DB: db} | |
err := r.Create(repo.Person{ | |
ID: pid, | |
Age: 31, | |
}) | |
if err != nil { | |
t.Fatalf("Creating account: %s", err) | |
} | |
// Have the workers perform the updates. | |
var wg sync.WaitGroup | |
wg.Add(len(updates)) | |
for _, upd := range updates { | |
go func(upd repo.Person) { | |
defer wg.Done() | |
err := r.Update(upd) | |
if err != nil { | |
t.Errorf("PersonRepo.Update(%#v) returned error: %s", upd, err) | |
} | |
}(upd) | |
} | |
read.Wait() // Check synchronisation occurred. | |
wg.Wait() // Wait for the workers to complete. | |
if t.Failed() { | |
return | |
} | |
// Check that all updates were applied. | |
exp := repo.Person{ID: pid, Name: "Paul", Age: 32} | |
act, err := r.Read(pid) | |
if err != nil { | |
t.Fatalf("Reading final account: %s", err) | |
} | |
if diff := cmp.Diff(exp, act); diff != "" { | |
t.Fatalf("Expected account (-) but got (+):\n%s", diff) | |
} | |
} | |
// sqlHooks holds the callbacks that sqlHookStmt invokes after each call into
// the real driver statement. Each callback receives:
//   - the SQL text the statement was prepared from,
//   - the arguments, and
//   - the result (or rows) and error produced by the real driver.
//
// It returns the values database/sql will see in their place. hookDB replaces
// nil callbacks with pass-through functions.
type sqlHooks struct {
	// ExecPost runs after driver.Stmt.Exec.
	ExecPost func(stmt string, params []driver.Value, result driver.Result, execErr error) (driver.Result, error)

	// QueryPost runs after driver.Stmt.Query.
	QueryPost func(stmt string, params []driver.Value, rows driver.Rows, queryErr error) (driver.Rows, error)
}
// hookDB returns a *sql.DB that will call hooks.ExecPost for each INSERT/UPDATE | |
// query run, and hooks.QueryPost for each SELECT query run. This works by | |
// wrapping the database connections (driver.Conn) with our own sqlHookConn | |
// which in turn wraps the database statements (driver.Stmt) with our own | |
// sqlHookStmt which invokes the hooks. | |
// | |
// The flow looks something like this: | |
// | |
// var db *sql.DB = realDB | |
// db = hookDB(db, sqlHooks{}) | |
// var tx *sql.Tx = db.BeginTx(...) | |
// | |
// This will now request a new database connection, which ends up chaining | |
// db.Conn() -> sqlHookConnector.Connect -> realDB.Connect -> sqlHookConn. We | |
// then try to perform tx.Exec(`UPDATE ...`) which ends up chaining | |
// sqlHookConn.Prepare -> sqlHookStmt -> sqlHookStmt.Exec -> sqlHooks.ExecPost. | |
// | |
// There are shortcut database/sql/driver interfaces such as Execer that we | |
// purposefully don't implement here to ensure that all queries are routed | |
// through sqlHookStmt and therefore not require us to implement the same logic | |
// in multiple places. | |
func hookDB(db *sql.DB, hooks sqlHooks) *sql.DB { | |
if hooks.ExecPost == nil { | |
hooks.ExecPost = func(query string, args []driver.Value, res driver.Result, err error) (driver.Result, error) { | |
return res, err | |
} | |
} | |
if hooks.QueryPost == nil { | |
hooks.QueryPost = func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error) { | |
return rows, err | |
} | |
} | |
return sql.OpenDB(&sqlHookConnector{hooks, db}) | |
} | |
type sqlHookConnector struct { | |
hooks sqlHooks | |
db *sql.DB | |
} | |
func (c *sqlHookConnector) Connect(ctx context.Context) (driver.Conn, error) { | |
sqlConn, err := c.db.Conn(ctx) | |
if err != nil { | |
return nil, err | |
} | |
var conn driverConnFull | |
err = sqlConn.Raw(func(driverConn interface{}) error { | |
// Apparently we're not supposed to do this. | |
conn = driverConn.(driverConnFull) | |
return nil | |
}) | |
return &sqlHookConn{c.hooks, conn}, err | |
} | |
func (c *sqlHookConnector) Driver() driver.Driver { | |
return c.db.Driver() | |
} | |
type sqlHookConn struct { | |
hooks sqlHooks | |
driverConnFull | |
} | |
type driverConnFull interface { | |
driver.Conn | |
driver.ConnBeginTx | |
} | |
func (c *sqlHookConn) Prepare(query string) (driver.Stmt, error) { | |
stmt, err := c.driverConnFull.Prepare(query) | |
return &sqlHookStmt{c.hooks, query, stmt}, err | |
} | |
type sqlHookStmt struct { | |
hooks sqlHooks | |
query string | |
driver.Stmt | |
} | |
func (s *sqlHookStmt) Exec(args []driver.Value) (driver.Result, error) { | |
r, err := s.Stmt.Exec(args) | |
return s.hooks.ExecPost(s.query, args, r, err) | |
} | |
func (s *sqlHookStmt) Query(args []driver.Value) (driver.Rows, error) { | |
r, err := s.Stmt.Query(args) | |
return s.hooks.QueryPost(s.query, args, r, err) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment