Last active
August 5, 2024 13:41
-
-
Save omerkaya1/fc272877536c4ed4af9f36e2ae57bf08 to your computer and use it in GitHub Desktop.
Golang PostgreSQL connection pool example Go ^1.18
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"github.com/prometheus/client_golang/prometheus" | |
"github.com/prometheus/client_golang/prometheus/promauto" | |
) | |
var queryTime = promauto.NewHistogramVec( | |
prometheus.HistogramOpts{ | |
Name: "db_client_query_execution_time", | |
Help: "The query execution time", | |
}, | |
[]string{"method"}, | |
) | |
func queryTimeObserver(method string) prometheus.Observer { | |
return queryTime.WithLabelValues(method) | |
} | |
// DB conn pool metrics. | |
var ( | |
acquireCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_acquire", | |
Help: "The cumulative count of successful acquires from the pool", | |
}, | |
) | |
acquiredCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_acquired", | |
Help: "The number of currently acquired connections in the pool", | |
}, | |
) | |
canceledAcquireCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_canceled_acquire_count", | |
Help: "The cumulative count of acquires from the pool that were canceled by a context", | |
}, | |
) | |
constructingConns = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_constructing", | |
Help: "The number of conns with construction in progress in the pool", | |
}, | |
) | |
emptyAcquireCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_empty_acquire_count", | |
Help: "The cumulative count of successful acquires from the pool that waited for a resource to be " + | |
"released or constructed because the pool was empty", | |
}, | |
) | |
idleConns = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_idle_conns", | |
Help: "The number of currently idle conns in the pool", | |
}, | |
) | |
maxConns = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_max_conns", | |
Help: "The maximum size of the pool", | |
}, | |
) | |
newConnsCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_new_conns_count", | |
Help: "The cumulative count of new connections opened", | |
}, | |
) | |
maxLifetimeDestroyCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_max_lifetime_destroy_count", | |
Help: "The cumulative count of connections destroyed because they exceeded MaxConnLifetime", | |
}, | |
) | |
maxIdleDestroyCount = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_max_idle_destroy_count", | |
Help: "The cumulative count of connections destroyed because they exceeded MaxConnIdleTime", | |
}, | |
) | |
totalConns = promauto.NewGauge( | |
prometheus.GaugeOpts{ | |
Name: "db_client_conn_pool_total_conns", | |
Help: "The total number of resources currently in the pool", | |
}, | |
) | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"context" | |
"database/sql" | |
"net" | |
"net/url" | |
"runtime" | |
"strconv" | |
"strings" | |
"time" | |
"github.com/exaring/otelpgx" | |
"github.com/jackc/pgx/v5" | |
"github.com/jackc/pgx/v5/pgxpool" | |
"github.com/jackc/pgx/v5/stdlib" | |
"github.com/pkg/errors" | |
"go.uber.org/zap" | |
) | |
//go:generate mockgen -destination=./mock/pgx.go -package=mock -typed github.com/jackc/pgx/v5 Rows,Row,Tx | |
//go:generate mockgen -source=./pool.go -destination=./mock/pool.go -typed -package=mock | |
type ( | |
ConnPool interface { | |
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) | |
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row | |
Begin(ctx context.Context) (pgx.Tx, error) | |
} | |
PGX struct { | |
pool *pgxpool.Pool | |
} | |
Config struct { | |
Tracing bool | |
ConnRetry int | |
MaxOpenConns int | |
MinOpenConns int | |
StatementCacheCapacity int | |
DescriptionCacheCapacity int | |
ConnTimeout time.Duration | |
MaxOpenConnTTL time.Duration | |
MaxIdleConnTTL time.Duration | |
MaxConnLifetimeJitterTTL time.Duration | |
QueryMode string | |
User string | |
Password string | |
Host string | |
Port string | |
Name string | |
} | |
) | |
func (cfg Config) DSN() string { // nolint:gocritic | |
query := make(url.Values) | |
if cfg.MaxOpenConns > 0 { | |
query.Set("pool_max_conns", strconv.Itoa(cfg.MaxOpenConns)) | |
} | |
if cfg.MinOpenConns > 0 { | |
query.Set("pool_min_conns", strconv.Itoa(cfg.MinOpenConns)) | |
} | |
if cfg.MaxOpenConnTTL > 0 { | |
query.Set("pool_max_conn_lifetime", cfg.MaxOpenConnTTL.String()) | |
} | |
if cfg.MaxIdleConnTTL > 0 { | |
query.Set("pool_max_conn_idle_time", cfg.MaxIdleConnTTL.String()) | |
} | |
if cfg.MaxIdleConnTTL > 0 { | |
query.Set("pool_max_conn_idle_time", cfg.MaxIdleConnTTL.String()) | |
} | |
if cfg.MaxConnLifetimeJitterTTL > 0 { | |
query.Set("pool_max_conn_lifetime_jitter", cfg.MaxConnLifetimeJitterTTL.String()) | |
} | |
if cfg.QueryMode != "" { | |
query.Set("default_query_exec_mode", cfg.QueryMode) | |
} | |
if cfg.StatementCacheCapacity >= 0 { | |
query.Set("statement_cache_capacity", strconv.Itoa(cfg.StatementCacheCapacity)) | |
} | |
if cfg.DescriptionCacheCapacity >= 0 { | |
query.Set("description_cache_capacity", strconv.Itoa(cfg.DescriptionCacheCapacity)) | |
} | |
dsn := url.URL{ | |
Scheme: "postgres", | |
User: url.UserPassword(cfg.User, cfg.Password), | |
Host: net.JoinHostPort(cfg.Host, cfg.Port), | |
Path: cfg.Name, | |
RawQuery: query.Encode(), | |
} | |
return dsn.String() | |
} | |
func NewPGX(ctx context.Context, cfg Config) (*PGX, error) { //nolint:gocritic | |
pCfg, err := pgxpool.ParseConfig(cfg.DSN()) | |
if err != nil { | |
return nil, errors.Wrap(err, "create db conn pool") | |
} | |
if cfg.Tracing { | |
pCfg.ConnConfig.Tracer = otelpgx.NewTracer( | |
otelpgx.WithIncludeQueryParameters(), | |
otelpgx.WithTrimSQLInSpanName(), | |
) | |
} | |
pool, err := pgxpool.NewWithConfig(ctx, pCfg) | |
if err != nil { | |
return nil, errors.Wrap(err, "create db conn pool") | |
} | |
if err = PingConnection(ctx, &cfg, func(pingCtx context.Context) error { | |
return pool.Ping(pingCtx) | |
}); err != nil { | |
return nil, errors.Wrap(err, "create db conn pool") | |
} | |
return &PGX{ | |
pool: pool, | |
}, nil | |
} | |
func PingConnection(ctx context.Context, cfg *Config, pinger func(ctx context.Context) error) error { | |
ticker := time.NewTicker(cfg.ConnTimeout) | |
defer ticker.Stop() | |
var err error | |
for i := 0; i < cfg.ConnRetry; i++ { | |
switch err = pinger(ctx); { | |
case err == nil: | |
return nil | |
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): | |
return errors.Wrap(err, "ping database connection") | |
default: | |
zap.L().Error("failed to ping database connection", zap.Error(err)) | |
} | |
select { | |
case <-ctx.Done(): | |
return errors.Wrap(ctx.Err(), "ping database connection") | |
case <-ticker.C: | |
} | |
} | |
return errors.Wrap(err, "ping database connection") | |
} | |
func (d *PGX) Close() error { | |
d.pool.Close() | |
return nil | |
} | |
func (d *PGX) DB() *sql.DB { | |
return stdlib.OpenDBFromPool(d.pool) | |
} | |
func (d *PGX) CollectMetrics(ctx context.Context) error { | |
errChan := make(chan error, 1) | |
defer close(errChan) | |
go func() { | |
s := d.pool.Stat() | |
totalConns.Set(float64(s.TotalConns())) | |
acquireCount.Set(float64(s.AcquireCount())) | |
acquiredCount.Set(float64(s.AcquiredConns())) | |
canceledAcquireCount.Set(float64(s.CanceledAcquireCount())) | |
constructingConns.Set(float64(s.ConstructingConns())) | |
emptyAcquireCount.Set(float64(s.EmptyAcquireCount())) | |
idleConns.Set(float64(s.IdleConns())) | |
maxConns.Set(float64(s.MaxConns())) | |
newConnsCount.Set(float64(s.NewConnsCount())) | |
maxLifetimeDestroyCount.Set(float64(s.MaxLifetimeDestroyCount())) | |
maxIdleDestroyCount.Set(float64(s.MaxIdleDestroyCount())) | |
errChan <- nil | |
}() | |
select { | |
case err := <-errChan: | |
return err | |
case <-ctx.Done(): | |
return ctx.Err() | |
} | |
} | |
func (d *PGX) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { | |
var ( | |
fn = funcName() | |
t = time.Now() | |
rows, err = d.pool.Query(ctx, sql, args...) | |
since = time.Since(t) | |
) | |
queryTimeObserver(fn).Observe(since.Seconds()) | |
if err != nil { | |
zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err))) | |
return nil, err | |
} | |
return &RowsWithMetrics{ | |
Rows: rows, | |
callerFnName: fn, | |
}, nil | |
} | |
func (d *PGX) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { | |
var ( | |
fn = funcName() | |
t = time.Now() | |
row = d.pool.QueryRow(ctx, sql, args...) | |
since = time.Since(t) | |
) | |
queryTimeObserver(fn).Observe(since.Seconds()) | |
return &RowWithMetrics{ | |
Row: row, | |
callerFnName: fn, | |
} | |
} | |
func (d *PGX) Begin(ctx context.Context) (pgx.Tx, error) { | |
var ( | |
fn = funcName() | |
t = time.Now() | |
tx, err = d.pool.Begin(ctx) | |
) | |
if err != nil { | |
zap.L().Error("failure to begin a transaction", zap.Error(errors.WithStack(err))) | |
return nil, err | |
} | |
return &TxWithMetrics{ | |
Tx: tx, | |
callerFnName: fn, | |
start: t, | |
}, nil | |
} | |
const (
	// skipN is the runtime.Caller depth used by funcName. Frame 0 is funcName
	// itself, frame 1 is the function that called it (e.g. PGX.Query), and
	// frame 2 is that function's caller.
	skipN = 2
	// delimiter separates package path, receiver and function name in the
	// names reported by the runtime.
	delimiter = "."
)

// funcName returns the short name of the function skipN frames up the stack.
// That is the caller of the method that invoked funcName. For example,
// "example.com/app/repo.(*Users).Get" yields "Get". It returns "" if the frame
// cannot be resolved.
//
// This runs on every query, so it slices after the last delimiter instead of
// building a []string with strings.Split. The result is identical but needs
// no allocation.
func funcName() string {
	pc, _, _, _ := runtime.Caller(skipN) // nolint:dogsled
	name := runtime.FuncForPC(pc).Name()
	// LastIndex returns -1 when there is no delimiter, so the whole name is kept.
	return name[strings.LastIndex(name, delimiter)+1:]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"github.com/jackc/pgx/v5" | |
"github.com/pkg/errors" | |
"go.uber.org/zap" | |
) | |
type ( | |
RowsWithMetrics struct { | |
pgx.Rows | |
callerFnName string | |
} | |
RowWithMetrics struct { | |
pgx.Row | |
callerFnName string | |
} | |
) | |
func (r *RowsWithMetrics) Next() bool { | |
ok := r.Rows.Next() | |
return ok | |
} | |
func (r *RowsWithMetrics) Scan(args ...any) error { | |
err := r.Rows.Scan(args...) | |
if err != nil { | |
zap.L().Error("failure to scan into the provided args", zap.Error(errors.WithStack(err))) | |
} | |
return err | |
} | |
func (r *RowsWithMetrics) Values() ([]interface{}, error) { | |
result, err := r.Rows.Values() | |
if err != nil { | |
zap.L().Error("failure to get row values", zap.Error(errors.WithStack(err))) | |
} | |
return result, err | |
} | |
func (r *RowsWithMetrics) Err() error { | |
err := r.Rows.Err() | |
if err != nil { | |
zap.L().Error("failure to iterate over returned rows", zap.Error(errors.WithStack(err))) | |
} | |
return err | |
} | |
func (r *RowWithMetrics) Scan(args ...any) error { | |
err := r.Row.Scan(args...) | |
if err != nil { | |
zap.L().Error("failure to scan into the provided args", zap.Error(errors.WithStack(err))) | |
} | |
return err | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"context" | |
"time" | |
"github.com/jackc/pgx/v5" | |
"github.com/jackc/pgx/v5/pgconn" | |
"github.com/pkg/errors" | |
"go.uber.org/zap" | |
) | |
type TxWithMetrics struct { | |
pgx.Tx | |
callerFnName string | |
start time.Time | |
} | |
func (t *TxWithMetrics) Commit(ctx context.Context) error { | |
var ( | |
err = t.Tx.Commit(ctx) | |
since = time.Since(t.start) | |
) | |
if err != nil { | |
zap.L().Error("failure to commit a transaction", zap.Error(errors.WithStack(err))) | |
} | |
queryTimeObserver(t.callerFnName).Observe(since.Seconds()) | |
return err | |
} | |
func (t *TxWithMetrics) Rollback(ctx context.Context) error { | |
var ( | |
err = t.Tx.Rollback(ctx) | |
since = time.Since(t.start) | |
) | |
if err != nil && !errors.Is(err, pgx.ErrTxClosed) { | |
zap.L().Error("failure to rollback a transaction", zap.Error(errors.WithStack(err))) | |
} | |
queryTimeObserver(t.callerFnName).Observe(since.Seconds()) | |
return nil | |
} | |
func (t *TxWithMetrics) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { | |
result, err := t.Tx.Exec(ctx, sql, args...) | |
if err != nil { | |
zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err))) | |
return pgconn.CommandTag{}, err | |
} | |
return result, nil | |
} | |
func (t *TxWithMetrics) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { | |
rows, err := t.Tx.Query(ctx, sql, args...) | |
if err != nil { | |
zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err))) | |
return nil, err | |
} | |
return &RowsWithMetrics{ | |
Rows: rows, | |
callerFnName: t.callerFnName, | |
}, nil | |
} | |
func (t *TxWithMetrics) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { | |
return &RowWithMetrics{ | |
Row: t.Tx.QueryRow(ctx, sql, args...), | |
callerFnName: t.callerFnName, | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment