Skip to content

Instantly share code, notes, and snippets.

@thiagozs
Created July 12, 2024 12:57
Show Gist options
  • Save thiagozs/772fd1246ef7f6cbca06f6a1fbf6dc4e to your computer and use it in GitHub Desktop.
Save thiagozs/772fd1246ef7f6cbca06f6a1fbf6dc4e to your computer and use it in GitHub Desktop.
Dynamically scan your DB query results into Go structs
package dynscan
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/google/uuid"
)
// DynamicScan decodes a query result into dest.
//
// result must be one of:
//   - *sql.Row: dest must be a pointer to a struct. The row's column names are
//     recovered by a separate "SELECT * FROM <tableName>" metadata query (see
//     fetchColumns), so the row must have been selected with SELECT * from the
//     same table; any other select list will be misaligned.
//   - *sql.Rows: dest must be a pointer to a slice of structs (or of struct
//     pointers). Column names come from the Rows itself, so db and tableName
//     are not used. All rows are consumed and the Rows is closed on return.
//
// Struct fields are matched to columns by the name part of their `json` tag.
// Any other result type yields an error.
//
// NOTE(review): in the *sql.Row case fetchColumns needs a second connection
// while the Row still holds its own; a pool capped at one open connection
// would block here — confirm pool settings at call sites.
func DynamicScan(db *sql.DB, result any, tableName string, dest any) error {
	switch res := result.(type) {
	case *sql.Row:
		columns, err := fetchColumns(db, tableName)
		if err != nil {
			// A *sql.Row keeps its connection until Scan runs. Calling Scan with
			// no destinations closes it; its own error is irrelevant here.
			_ = res.Scan()
			return fmt.Errorf("error fetching columns: %w", err)
		}
		return scanRow(res, columns, dest)
	case *sql.Rows:
		// Release the connection on every exit path. Close is idempotent, and
		// after a full iteration the Rows is already closed anyway.
		defer res.Close()
		columns, err := res.Columns()
		if err != nil {
			return fmt.Errorf("error fetching columns from rows: %w", err)
		}
		return scanRows(res, columns, dest)
	default:
		return fmt.Errorf("unsupported result type %T", result)
	}
}
// scanRow decodes a single *sql.Row into dest, which must be a pointer to a
// struct. columns must list the row's column names in select-list order; each
// column is bound to the struct field whose json tag name matches it, and
// unmatched columns are discarded.
//
// Errors from Scan (including sql.ErrNoRows) are wrapped with %w so callers
// can still test them with errors.Is.
func scanRow(row *sql.Row, columns []string, dest any) error {
	fieldPtrs, fieldMap, err := prepareFieldPointers(dest, columns)
	if err != nil {
		// The Row holds its connection until Scan runs; scan with no
		// destinations purely to close it, then report the real error.
		_ = row.Scan()
		return err
	}
	if err := row.Scan(fieldPtrs...); err != nil {
		return fmt.Errorf("error scanning row: %w", err)
	}
	if err := setStructFields(dest, columns, fieldPtrs, fieldMap); err != nil {
		return fmt.Errorf("error setting struct fields: %w", err)
	}
	return nil
}
// scanRows scans every remaining row of rows into the slice that dest points to.
//
// dest must be a pointer to a slice whose element type is a struct or a
// pointer to a struct ([]T or []*T). Each row is decoded into a freshly
// zeroed T, with columns matched to fields by json tag name, and the result is
// appended; elements already in the slice are kept. If a row fails, the rows
// decoded before it stay appended and the error is returned. Once iteration
// ends, rows.Err() is returned.
func scanRows(rows *sql.Rows, columns []string, dest any) error {
	destVal := reflect.ValueOf(dest)
	if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice {
		return fmt.Errorf("dest must be a pointer to a slice")
	}
	destSlice := destVal.Elem()
	elemType := destSlice.Type().Elem()

	// Accept both []T and []*T; structType is T in either case.
	elemIsPtr := elemType.Kind() == reflect.Ptr
	structType := elemType
	if elemIsPtr {
		structType = elemType.Elem()
	}
	if structType.Kind() != reflect.Struct {
		return fmt.Errorf("dest slice element must be a struct or pointer to struct, got %s", elemType)
	}

	for rows.Next() {
		item := reflect.New(structType) // *T, zeroed for this row
		target := item.Interface()
		fieldPtrs, fieldMap, err := prepareFieldPointers(target, columns)
		if err != nil {
			return err
		}
		if err := rows.Scan(fieldPtrs...); err != nil {
			return fmt.Errorf("error scanning rows: %w", err)
		}
		if err := setStructFields(target, columns, fieldPtrs, fieldMap); err != nil {
			return fmt.Errorf("error setting struct fields: %w", err)
		}
		if elemIsPtr {
			destSlice.Set(reflect.Append(destSlice, item))
		} else {
			destSlice.Set(reflect.Append(destSlice, item.Elem()))
		}
	}
	return rows.Err()
}
// prepareFieldPointers builds the Scan destinations for one row.
//
// dest must be a non-nil pointer to a struct; otherwise an error is returned.
// Exported fields are bound to columns by the name part of their `json` tag
// (text before the first comma). Fields that are unexported, untagged, or
// tagged "-" are never bound. When two fields share a tag name, the later
// field wins.
//
// It returns:
//   - fieldPtrs: one Scan target per column, in column order. Columns with no
//     bound field get a throwaway *interface{} so Scan still succeeds.
//     time.Time, bool, string, uuid.UUID and []byte fields get intermediate
//     holders (sql.NullTime, sql.NullBool, sql.NullString, *[]byte) so NULL
//     scans without error; setStructFields copies them into the struct.
//     Every other field is scanned directly through its address. For a *T
//     field that address is a **T, which database/sql sets to nil on NULL and
//     allocates otherwise. Use pointer fields for nullable columns of other
//     types.
//   - fieldMap: json tag name -> struct field index, consumed by setStructFields.
func prepareFieldPointers(dest any, columns []string) ([]interface{}, map[string]int, error) {
	rv := reflect.ValueOf(dest)
	if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
		return nil, nil, fmt.Errorf("dest must be a non-nil pointer to a struct, got %T", dest)
	}
	v := rv.Elem()
	t := v.Type()

	fieldMap := make(map[string]int, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if !sf.IsExported() {
			// reflect can neither address-export nor set unexported fields.
			continue
		}
		name := cleanTag(sf.Tag.Get("json"))
		if name == "" || name == "-" {
			// No usable column name: leave the field unbound.
			continue
		}
		fieldMap[name] = i
	}

	fieldPtrs := make([]interface{}, len(columns))
	for i, column := range columns {
		idx, ok := fieldMap[column]
		if !ok {
			fieldPtrs[i] = new(interface{}) // unmapped column: scan and discard
			continue
		}
		field := v.Field(idx)
		switch field.Type() {
		case reflect.TypeOf(time.Time{}):
			fieldPtrs[i] = new(sql.NullTime)
		case reflect.TypeOf(false):
			fieldPtrs[i] = new(sql.NullBool)
		case reflect.TypeOf(""), reflect.TypeOf(uuid.UUID{}):
			// uuid.UUID is read as text and parsed when the struct is filled.
			fieldPtrs[i] = new(sql.NullString)
		case reflect.TypeOf([]byte(nil)):
			fieldPtrs[i] = new([]byte)
		default:
			fieldPtrs[i] = field.Addr().Interface()
		}
	}
	return fieldPtrs, fieldMap, nil
}
// setStructFields sets the scanned values to the struct fields
func setStructFields(dest any, columns []string, fieldPtrs []interface{}, fieldMap map[string]int) error {
v := reflect.ValueOf(dest).Elem()
for i, column := range columns {
idx, ok := fieldMap[column]
if !ok {
continue
}
field := v.Field(idx)
fieldType := field.Type()
if !field.CanSet() {
return fmt.Errorf("cannot set field %s", column)
}
switch ptr := fieldPtrs[i].(type) {
case *sql.NullTime:
if ptr.Valid {
field.Set(reflect.ValueOf(ptr.Time))
} else {
field.Set(reflect.ValueOf(time.Time{}))
}
case *sql.NullBool:
if ptr.Valid {
field.SetBool(ptr.Bool)
} else {
field.SetBool(false)
}
case *sql.NullString:
if ptr.Valid {
if fieldType == reflect.TypeOf(uuid.UUID{}) {
u, err := uuid.Parse(ptr.String)
if err != nil {
return fmt.Errorf("error parsing UUID: %w", err)
}
field.Set(reflect.ValueOf(u))
} else {
field.SetString(ptr.String)
}
} else {
field.Set(reflect.Zero(fieldType))
}
case *[]byte:
field.SetBytes(*ptr)
default:
if field.Kind() == reflect.Ptr {
if reflect.ValueOf(ptr).Elem().IsValid() {
field.Set(reflect.ValueOf(ptr).Elem())
} else {
field.Set(reflect.Zero(fieldType))
}
} else {
field.Set(reflect.ValueOf(ptr).Elem())
}
}
}
return nil
}
// fetchColumns returns the column names of tableName, in the order that
// "SELECT *" produces them, by running a query that matches no rows and
// reading only the result-set metadata.
//
// WARNING: tableName is interpolated directly into the SQL text (identifiers
// cannot be bound as query parameters). It must come from trusted code, never
// from user input.
func fetchColumns(db *sql.DB, tableName string) ([]string, error) {
	if strings.TrimSpace(tableName) == "" {
		return nil, fmt.Errorf("table name is empty")
	}
	// WHERE 1 = 0 yields the column list without transferring any data row.
	query := fmt.Sprintf("SELECT * FROM %s WHERE 1 = 0", tableName)
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return rows.Columns()
}
// cleanTag returns the name portion of a struct tag value: everything before
// the first comma, e.g. "id,omitempty" -> "id" and "" -> "".
func cleanTag(tag string) string {
	name, _, _ := strings.Cut(tag, ",")
	return name
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment