Skip to content

Instantly share code, notes, and snippets.

@derekperkins
Created May 9, 2024 23:18
Show Gist options
  • Save derekperkins/0cbdf7f48f79a3f91948440c6892a457 to your computer and use it in GitHub Desktop.
Save derekperkins/0cbdf7f48f79a3f91948440c6892a457 to your computer and use it in GitHub Desktop.

MySQL

expandArgs

expandArgs expands any named args that are slices into multiple named args, one for each element in the slice. This is useful for queries that use the IN operator, e.g. SELECT * FROM table WHERE id IN (:id)

  • If there are no args, the original stmt and args are returned.
  • If there are no slices to expand, the original stmt and args are returned.
  • If the args are positional, the original stmt and args are returned, as we can't expand positional args.
  • Only slices of primitives are supported. Slices of structs or other types are not supported.
  • Prepared statements (Stmts) are not supported, as the stmt string can't be modified after Prepare is called.

Benchmarked 2024-05-09 on an M3 Max MacBook Pro 2023, 16-inch, 36GB RAM, macOS 14.4.1, Go 1.22.3

goos: darwin
goarch: arm64
pkg: go.nozzle.io/pkg/mysql
Benchmark_expandArgs
Benchmark_expandArgs/no_args
Benchmark_expandArgs/no_args-14         	                                                857257668	         1.386 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/positional:_no_array_arg,_1_arg
Benchmark_expandArgs/positional:_no_array_arg,_1_arg-14         	                        566456900	         2.110 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/positional:_no_array_arg,_1,000_ints
Benchmark_expandArgs/positional:_no_array_arg,_1,000_ints-14    	                        570863935	         2.113 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/positional:_no_array_arg,_1,000_strings
Benchmark_expandArgs/positional:_no_array_arg,_1,000_strings-14 	                        564158204	         2.142 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/positional:_no_array_arg,_100,000_ints
Benchmark_expandArgs/positional:_no_array_arg,_100,000_ints-14  	                        563225262	         2.116 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/positional:_no_array_arg,_100,000_strings
Benchmark_expandArgs/positional:_no_array_arg,_100,000_strings-14         	                563380719	         2.115 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_no_array_arg
Benchmark_expandArgs/named:_no_array_arg-14                               	                230793308	         5.211 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_one_array_arg_with_3_elements
Benchmark_expandArgs/named:_one_array_arg_with_3_elements-14              	                250725536	         4.831 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_one_array_arg_with_1,000_elements
Benchmark_expandArgs/named:_one_array_arg_with_1,000_elements-14          	                248725476	         4.808 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_one_array_arg_with_100,000_elements
Benchmark_expandArgs/named:_one_array_arg_with_100,000_elements-14        	                245566166	         5.083 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_one_array_arg_with_3_elements_and_one_normal_arg
Benchmark_expandArgs/named:_one_array_arg_with_3_elements_and_one_normal_arg-14         	171625336	         7.009 ns/op	       0 B/op	       0 allocs/op
Benchmark_expandArgs/named:_one_array_arg_with_100,000_elements_and_one_normal_arg
Benchmark_expandArgs/named:_one_array_arg_with_100,000_elements_and_one_normal_arg-14   	172396051	         6.967 ns/op	       0 B/op	       0 allocs/op
package mysql
import (
"bytes"
"database/sql"
"reflect"
"slices"
"strconv"
"strings"
"sync"
)
// bufferPool recycles the scratch buffers that expandArgs uses to render
// expanded placeholder lists.
var bufferPool = &sync.Pool{
	New: func() any {
		return &bytes.Buffer{}
	},
}

// getBuffer returns a buffer from the pool.
func getBuffer() (buf *bytes.Buffer) {
	return bufferPool.Get().(*bytes.Buffer)
}

// putBuffer returns a buffer to the pool.
// The buffer is reset before it is put back into circulation. Buffers that
// have grown beyond maxPooledCap are dropped instead of pooled, so a single
// very large expansion doesn't pin its memory inside the pool.
func putBuffer(buf *bytes.Buffer) {
	const maxPooledCap = 64 << 10 // 64 KiB

	if buf.Cap() > maxPooledCap {
		return
	}
	buf.Reset()
	bufferPool.Put(buf)
}

// expandArgs expands any named args that are slices into multiple named args, one for each element in the slice.
// This is useful for queries that use the IN operator, e.g. SELECT * FROM table WHERE id IN (:id)
//   - If there are no args, the original stmt and args are returned.
//   - If there are no slices to expand, the original stmt and args are returned.
//   - If the args are positional, the original stmt and args are returned, as we can't expand positional args.
//   - Only slices of primitives are supported. Slices of structs or other types are not supported.
//   - Byte slices ([]byte and other slices whose element kind is uint8) are never expanded; they are left
//     untouched so they can be sent as a single value.
//   - An empty slice expands to (NULL). Note that under SQL three-valued logic both "x IN (NULL)" and
//     "x NOT IN (NULL)" evaluate to NULL, so neither form matches any row.
//   - Stmts are not supported, as the stmt string can't be modified after Prepare is called
//
// The caller's args are never modified: when something is expanded, a new slice is returned that holds
// the original args (with each expanded arg's value replaced by nil) followed by the generated args.
// Generated names are the original name suffixed with the element index (ids -> ids0, ids1, ...), so
// they can collide with a separately supplied arg of the same name.
func expandArgs(stmt string, args ...any) (string, []any) {
	// fast path if there are no args
	if len(args) == 0 {
		return stmt, args
	}

	// fast path if the first arg is not a named arg. We can't expand positional args, and named / positional args
	// can't be mixed in a single query. We could support this, but it would require a full sql parser to be correct.
	if _, isNamed := args[0].(sql.NamedArg); !isNamed {
		return stmt, args
	}

	var (
		// buf is taken from the pool lazily, so queries without slices never touch the pool.
		buf *bytes.Buffer
		// out is the result slice. It stays nil until the first slice arg is found, and is a copy of args
		// rather than args itself: args shares its backing array with the caller's slice when the function
		// is invoked as expandArgs(stmt, xs...), and writing into it would corrupt xs for later reuse.
		out []any
	)

	for i, arg := range args {
		// A positional arg after a named one is left alone for the sql driver to reject.
		namedArg, ok := arg.(sql.NamedArg)
		if !ok {
			continue
		}

		// Only slices are expanded. uint8-element slices are skipped because they represent a single
		// binary value rather than a list.
		v := reflect.ValueOf(namedArg.Value)
		if v.Kind() != reflect.Slice || v.Type().Elem().Kind() == reflect.Uint8 {
			continue
		}

		n := v.Len()
		if out == nil {
			out = make([]any, len(args), len(args)+n)
			copy(out, args)
		} else {
			out = slices.Grow(out, n)
		}

		if buf == nil {
			// if we don't have a buffer, get one from the pool. This only happens once in this function
			buf = getBuffer()
		} else {
			// if we do have a buffer, reset it and reuse it. It should be rare that there are two slices to expand
			buf.Reset()
		}

		// iterate over the slice to:
		// 1. append each element to the result as its own named arg
		// 2. write a reference to that named arg to the buffer, using the index as a suffix
		buf.WriteByte('(')
		if n == 0 {
			// "IN ()" is not valid SQL; NULL keeps the statement parseable and matches nothing
			buf.WriteString("NULL")
		}
		for j := 0; j < n; j++ {
			if j > 0 {
				buf.WriteByte(',')
			}
			name := namedArg.Name + strconv.Itoa(j)
			out = append(out, sql.Named(name, v.Index(j).Interface()))
			buf.WriteByte(':')
			buf.WriteString(name)
		}
		buf.WriteByte(')')

		// to be extra safe, we're only replacing named args immediately surrounded by parentheses, so we don't
		// accidentally replace a named arg that is part of a larger string. e.g.
		// CORRECT: SELECT * FROM table WHERE id IN (:id)
		// INCORRECT: SELECT * FROM table WHERE id IN (:id )
		stmt = strings.ReplaceAll(stmt, "(:"+namedArg.Name+")", buf.String())

		// we could delete the original named arg from the result, but that is probably unnecessary. We'll nil out
		// the value (in the copy only) to make sure it's not used again for both correctness and performance
		out[i] = sql.Named(namedArg.Name, nil)
	}

	// if there are no slices to expand, return the original args
	if buf == nil {
		return stmt, args
	}

	// return the buffer to the pool. This is purposefully not done with a defer: the original benchmarks
	// measured a ~3x slowdown (2ns -> 6ns) on the fast paths with one.
	putBuffer(buf)

	return stmt, out
}
package mysql
import (
"database/sql"
"testing"
"github.com/maxatome/go-testdeep/td"
)
// Test_expandArgs runs expandArgs against a table of statements and argument
// lists and compares both the returned statement and the returned args with
// the expected values. Positional cases must pass through untouched; named
// slice args must be rewritten into one indexed named arg per element.
func Test_expandArgs(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name     string
		stmt     string
		args     []any
		wantStmt string
		wantArgs []any
	}

	cases := []testCase{
		{
			name:     "no args",
			stmt:     "SELECT * FROM table",
			args:     nil,
			wantStmt: "SELECT * FROM table",
			wantArgs: nil,
		},
		{
			name:     "positional: no array arg",
			stmt:     "SELECT * FROM table WHERE id = ?",
			args:     []any{1},
			wantStmt: "SELECT * FROM table WHERE id = ?",
			wantArgs: []any{1},
		},
		{
			name:     "named: no array arg",
			stmt:     "SELECT * FROM table WHERE id = :id",
			args:     []any{sql.Named("id", 1)},
			wantStmt: "SELECT * FROM table WHERE id = :id",
			wantArgs: []any{sql.Named("id", 1)},
		},
		{
			// positional slices are deliberately left alone
			name:     "positional: one array arg",
			stmt:     "SELECT * FROM table WHERE id IN (?)",
			args:     []any{[]int{1, 2, 3}},
			wantStmt: "SELECT * FROM table WHERE id IN (?)",
			wantArgs: []any{[]int{1, 2, 3}},
		},
		{
			// the original named arg stays in place with a nil value
			name:     "named: one array arg",
			stmt:     "SELECT * FROM table WHERE id IN (:ids)",
			args:     []any{sql.Named("ids", []int{1, 2, 3})},
			wantStmt: "SELECT * FROM table WHERE id IN (:ids0,:ids1,:ids2)",
			wantArgs: []any{
				sql.Named("ids", nil),
				sql.Named("ids0", 1),
				sql.Named("ids1", 2),
				sql.Named("ids2", 3),
			},
		},
		{
			name:     "named: one array arg and one normal arg",
			stmt:     "SELECT * FROM table WHERE id IN (:ids) AND batch_id = :batch_id",
			args:     []any{sql.Named("batch_id", 1), sql.Named("ids", []int{1, 2, 3})},
			wantStmt: "SELECT * FROM table WHERE id IN (:ids0,:ids1,:ids2) AND batch_id = :batch_id",
			wantArgs: []any{
				sql.Named("batch_id", 1),
				sql.Named("ids", nil),
				sql.Named("ids0", 1),
				sql.Named("ids1", 2),
				sql.Named("ids2", 3),
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			stmt, args := expandArgs(tc.stmt, tc.args...)
			td.Cmp(t, stmt, tc.wantStmt)
			td.Cmp(t, args, tc.wantArgs)
		})
	}
}
// Benchmark_expandArgs measures expandArgs across the fast paths (no args,
// positional args) and the named-arg paths with and without slices to expand.
//
// Every iteration is handed a pristine argument list. The variadic call
// expandArgs(stmt, xs...) shares xs's backing array with the callee, so if
// expandArgs writes into the slice it receives, reusing one slice across
// iterations would exercise the expansion path on the first iteration only and
// report fast-path timings for the slice cases.
func Benchmark_expandArgs(b *testing.B) {
	benchmarks := []struct {
		name string
		sql  string
		args []any
	}{
		{
			"no args",
			"SELECT * FROM table",
			nil,
		},
		{
			"positional: no array arg, 1 arg",
			"SELECT * FROM table WHERE id = ?",
			[]any{1},
		},
		{
			"positional: no array arg, 1,000 ints",
			"SELECT * FROM table WHERE id = ?",
			generateArgs(1, 1_000),
		},
		{
			"positional: no array arg, 1,000 strings",
			"SELECT * FROM table WHERE id = ?",
			generateArgs("I'm a fun test string", 1000),
		},
		{
			"positional: no array arg, 100,000 ints",
			"SELECT * FROM table WHERE id = ?",
			generateArgs(1, 100_000),
		},
		{
			"positional: no array arg, 100,000 strings",
			"SELECT * FROM table WHERE id = ?",
			generateArgs("I'm a fun test string", 100_000),
		},
		{
			"named: no array arg",
			"SELECT * FROM table WHERE id = :id",
			[]any{sql.Named("id", 1)},
		},
		{
			"named: one array arg with 3 elements",
			"SELECT * FROM table WHERE id IN (:id)",
			[]any{sql.Named("id", []int{1, 2, 3})},
		},
		{
			"named: one array arg with 1,000 elements",
			"SELECT * FROM table WHERE id IN (:id)",
			[]any{sql.Named("id", generateArgs(1, 1_000))},
		},
		{
			"named: one array arg with 100,000 elements",
			"SELECT * FROM table WHERE id IN (:id)",
			[]any{sql.Named("id", generateArgs(1, 100_000))},
		},
		{
			"named: one array arg with 3 elements and one normal arg",
			"SELECT * FROM table WHERE id IN (:id) AND batch_id = :batch_id",
			[]any{sql.Named("batch_id", 1), sql.Named("id", []int{1, 2, 3})},
		},
		{
			"named: one array arg with 100,000 elements and one normal arg",
			"SELECT * FROM table WHERE id IN (:id) AND batch_id = :batch_id",
			[]any{sql.Named("batch_id", 1), sql.Named("id", generateArgs(1, 100_000))},
		},
	}

	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			b.ReportAllocs()

			// scratch is the slice actually passed to expandArgs; bm.args stays pristine.
			// A nil bm.args yields a nil scratch, so the "no args" case is unchanged.
			scratch := append([]any(nil), bm.args...)

			// Only named argument lists can be rewritten by expandArgs (positional lists take
			// the early-return path), so only those need restoring. This keeps the large
			// positional cases free of a per-iteration copy.
			restore := false
			if len(scratch) > 0 {
				_, restore = scratch[0].(sql.NamedArg)
			}

			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				if restore {
					copy(scratch, bm.args)
				}
				expandArgs(bm.sql, scratch...)
			}
		})
	}
}
// generateArgs builds a slice of length n in which every element is the value
// t boxed as any. The benchmarks use it to fabricate large argument lists.
func generateArgs[T any](t T, n int) []any {
	filled := make([]any, n)
	for idx := range filled {
		filled[idx] = t
	}
	return filled
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment