Last active
July 22, 2020 13:09
-
-
Save renthraysk/7bbd7ca91a1ceed4617cc264d179d30e to your computer and use it in GitHub Desktop.
Experimenting with go2 generics and databases
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"log" | |
"database/sql" | |
"github.com/go-sql-driver/mysql" | |
) | |
// Collector is a sink for values of type T. Repository.scan hands every
// mapped row to a Collector through Add; how the value is stored is left to
// the implementation.
type Collector[T any] interface {
	// Add hands a single value to the collector.
	Add(T)
}
// Value is the constraint used for the element types of Slice and
// IdentityMap. It declares no methods, so every type satisfies it.
type Value interface{}
type Slice[type T Value] []T | |
func (s *Slice[T])Add(t T) { *s = append(*s, t) } | |
type IdentityMap[type K comparable, V Value] struct { | |
m map[K]V | |
f func(V) K | |
} | |
func NewIdentityMap[K comparable, V Value](f func(V) K) *IdentityMap[K, V] { | |
return &IdentityMap[K, V]{m: make(map[K]V), f: f} | |
} | |
func (i *IdentityMap[K, V])Add(v V) { | |
i.m[i.f(v)] = v | |
} | |
func (i *IdentityMap[K, V]) Map() map[K]V { return i.m } | |
type Mapper[type T] func(s Scanner, t T) error | |
// Repository carries the database handle used by the generic scan and query
// helpers, which load result rows as values of type T.
type Repository[T any] struct {
	db *sql.DB
}
func (r *Repository[T]) scan(rows *sql.Rows, m Mapper[*T], collect Collector[*T]) error { | |
for rows.Next() { | |
t := new(T) | |
if err := m(rows, t); err != nil { | |
return err | |
} | |
collect.Add(t) | |
} | |
return rows.Err() | |
} | |
func (r *Repository[T]) query(query string, m Mapper[*T], collector Collector[*T]) error { | |
rows, err := r.db.Query(query) | |
if err != nil { | |
return err | |
} | |
if err := r.scan(rows, m, collector); err != nil { | |
rows.Close() | |
return err | |
} | |
return rows.Close() | |
} | |
// Scanner is the single capability a Mapper needs from a result row: copying
// the row's columns into the supplied destinations. Repository.scan passes a
// *sql.Rows where a Scanner is expected.
type Scanner interface {
	Scan(dest ...interface{}) error
}
// Actor is one row of the `actor` table as selected by
// ActorRepository.SelectAll: the actor_id, first_name and last_name columns,
// in that order.
type Actor struct {
	ActorID             uint16
	FirstName, LastName string
}
type ActorRepository struct { | |
r Repository[Actor] | |
} | |
func (ActorRepository) key(a *Actor) uint16 { return a.ActorID } | |
func (ActorRepository) scan(s Scanner, a *Actor) error { | |
return s.Scan(&a.ActorID, &a.FirstName, &a.LastName) | |
} | |
func (r *ActorRepository) NewMap() *IdentityMap[uint16, *Actor] { | |
return NewIdentityMap[uint16, *Actor](r.key) | |
} | |
func (r *ActorRepository) SelectAll(collector Collector[*Actor]) error { | |
return r.r.query("SELECT `actor_id`, `first_name`, `last_name` FROM `actor` LIMIT 10", r.scan, collector) | |
} | |
func getDB() (*sql.DB, error) { | |
cfg := mysql.Config{ | |
Net: "unix", | |
Addr: "/var/run/mysqld/mysqld.sock", | |
User: "test", | |
Passwd: "KfD+WHMw-D=gf2au", | |
DBName: "sakila", | |
AllowNativePasswords: true, | |
} | |
return sql.Open("mysql", cfg.FormatDSN()) | |
} | |
func main() { | |
db, err := getDB() | |
if err != nil { | |
log.Fatalf("failed to open db: %v", err) | |
} | |
defer db.Close() | |
r := ActorRepository{r: Repository[Actor]{db: db}} | |
{ | |
fmt.Println("--- Slice") | |
var actors Slice[*Actor] | |
if err := r.SelectAll(&actors); err != nil { | |
log.Fatalf("select actors failed: %v", err) | |
} | |
for _, a := range actors { | |
fmt.Println(a.FirstName) | |
} | |
} | |
{ | |
fmt.Println("--- Map") | |
actors := r.NewMap() | |
if err := r.SelectAll(actors); err != nil { | |
log.Fatalf("select actors failed: %v", err) | |
} | |
for _, a := range actors.Map() { | |
fmt.Println(a.FirstName) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment