Skip to content

Instantly share code, notes, and snippets.

@omerkaya1
Last active August 5, 2024 13:41
Show Gist options
  • Save omerkaya1/fc272877536c4ed4af9f36e2ae57bf08 to your computer and use it in GitHub Desktop.
Save omerkaya1/fc272877536c4ed4af9f36e2ae57bf08 to your computer and use it in GitHub Desktop.
Golang PostgreSQL connection pool example Go ^1.18
package db
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// queryTime is the histogram vector that records query execution times.
// Observations are partitioned by the "method" label.
var queryTime = promauto.NewHistogramVec(prometheus.HistogramOpts{
	Name: "db_client_query_execution_time",
	Help: "The query execution time",
}, []string{"method"})
// queryTimeObserver returns the observer of the queryTime histogram that is
// bound to the given "method" label value.
func queryTimeObserver(method string) prometheus.Observer {
	observer := queryTime.WithLabelValues(method)

	return observer
}
// Connection pool metrics. Every gauge below is overwritten with the matching
// value of a pool statistics snapshot each time PGX.CollectMetrics runs.
var (
	// acquireCount mirrors the snapshot's AcquireCount.
	acquireCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_acquire",
		Help: "The cumulative count of successful acquires from the pool",
	})

	// acquiredCount mirrors the snapshot's AcquiredConns.
	acquiredCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_acquired",
		Help: "The number of currently acquired connections in the pool",
	})

	// canceledAcquireCount mirrors the snapshot's CanceledAcquireCount.
	canceledAcquireCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_canceled_acquire_count",
		Help: "The cumulative count of acquires from the pool that were canceled by a context",
	})

	// constructingConns mirrors the snapshot's ConstructingConns.
	constructingConns = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_constructing",
		Help: "The number of conns with construction in progress in the pool",
	})

	// emptyAcquireCount mirrors the snapshot's EmptyAcquireCount.
	emptyAcquireCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_empty_acquire_count",
		Help: "The cumulative count of successful acquires from the pool that waited for a resource to be " +
			"released or constructed because the pool was empty",
	})

	// idleConns mirrors the snapshot's IdleConns.
	idleConns = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_idle_conns",
		Help: "The number of currently idle conns in the pool",
	})

	// maxConns mirrors the snapshot's MaxConns.
	maxConns = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_max_conns",
		Help: "The maximum size of the pool",
	})

	// newConnsCount mirrors the snapshot's NewConnsCount.
	newConnsCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_new_conns_count",
		Help: "The cumulative count of new connections opened",
	})

	// maxLifetimeDestroyCount mirrors the snapshot's MaxLifetimeDestroyCount.
	maxLifetimeDestroyCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_max_lifetime_destroy_count",
		Help: "The cumulative count of connections destroyed because they exceeded MaxConnLifetime",
	})

	// maxIdleDestroyCount mirrors the snapshot's MaxIdleDestroyCount.
	maxIdleDestroyCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_max_idle_destroy_count",
		Help: "The cumulative count of connections destroyed because they exceeded MaxConnIdleTime",
	})

	// totalConns mirrors the snapshot's TotalConns.
	totalConns = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "db_client_conn_pool_total_conns",
		Help: "The total number of resources currently in the pool",
	})
)
package db
import (
"context"
"database/sql"
"net"
"net/url"
"runtime"
"strconv"
"strings"
"time"
"github.com/exaring/otelpgx"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/pkg/errors"
"go.uber.org/zap"
)
//go:generate mockgen -destination=./mock/pgx.go -package=mock -typed github.com/jackc/pgx/v5 Rows,Row,Tx
//go:generate mockgen -source=./pool.go -destination=./mock/pool.go -typed -package=mock
// ConnPool is the subset of query operations exposed to consumers of the
// package. PGX provides all three methods.
type ConnPool interface {
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
	Begin(ctx context.Context) (pgx.Tx, error)
}

// PGX wraps a pgx connection pool and adds logging and query-time metrics to
// the operations it forwards.
type PGX struct {
	pool *pgxpool.Pool
}

// Config carries the settings used to build the connection string (see DSN)
// and to open and verify the pool (see NewPGX and PingConnection).
type Config struct {
	// Tracing makes NewPGX install an otelpgx tracer on the connection config.
	Tracing bool
	// ConnRetry bounds how many times PingConnection invokes the pinger.
	ConnRetry int
	// MaxOpenConns is sent as pool_max_conns when positive.
	MaxOpenConns int
	// MinOpenConns is sent as pool_min_conns when positive.
	MinOpenConns int
	// StatementCacheCapacity is sent as statement_cache_capacity whenever it
	// is non-negative, so the zero value is sent as 0.
	StatementCacheCapacity int
	// DescriptionCacheCapacity is sent as description_cache_capacity whenever
	// it is non-negative, so the zero value is sent as 0.
	DescriptionCacheCapacity int
	// ConnTimeout is the interval PingConnection waits between ping attempts.
	ConnTimeout time.Duration
	// MaxOpenConnTTL is sent as pool_max_conn_lifetime when positive.
	MaxOpenConnTTL time.Duration
	// MaxIdleConnTTL is sent as pool_max_conn_idle_time when positive.
	MaxIdleConnTTL time.Duration
	// MaxConnLifetimeJitterTTL is sent as pool_max_conn_lifetime_jitter when
	// positive.
	MaxConnLifetimeJitterTTL time.Duration
	// QueryMode is sent as default_query_exec_mode when non-empty.
	QueryMode string
	// User and Password form the user-info part of the DSN.
	User     string
	Password string
	// Host and Port are joined into the host part of the DSN.
	Host string
	Port string
	// Name is used as the DSN path.
	Name string
}
// DSN renders the configuration as a "postgres://" connection URL.
//
// User and Password become the user-info part, Host and Port the host part,
// and Name the path. Pool and cache settings are appended as query parameters:
// pool sizes and durations only when positive, the query mode only when
// non-empty, and both cache capacities whenever they are non-negative, which
// means a zero capacity is sent explicitly as 0.
func (cfg Config) DSN() string { // nolint:gocritic
	query := make(url.Values)

	if cfg.MaxOpenConns > 0 {
		query.Set("pool_max_conns", strconv.Itoa(cfg.MaxOpenConns))
	}
	if cfg.MinOpenConns > 0 {
		query.Set("pool_min_conns", strconv.Itoa(cfg.MinOpenConns))
	}
	if cfg.MaxOpenConnTTL > 0 {
		query.Set("pool_max_conn_lifetime", cfg.MaxOpenConnTTL.String())
	}
	// The original set this parameter in two identical consecutive branches;
	// one is enough.
	if cfg.MaxIdleConnTTL > 0 {
		query.Set("pool_max_conn_idle_time", cfg.MaxIdleConnTTL.String())
	}
	if cfg.MaxConnLifetimeJitterTTL > 0 {
		query.Set("pool_max_conn_lifetime_jitter", cfg.MaxConnLifetimeJitterTTL.String())
	}
	if cfg.QueryMode != "" {
		query.Set("default_query_exec_mode", cfg.QueryMode)
	}
	if cfg.StatementCacheCapacity >= 0 {
		query.Set("statement_cache_capacity", strconv.Itoa(cfg.StatementCacheCapacity))
	}
	if cfg.DescriptionCacheCapacity >= 0 {
		query.Set("description_cache_capacity", strconv.Itoa(cfg.DescriptionCacheCapacity))
	}

	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(cfg.User, cfg.Password),
		Host:     net.JoinHostPort(cfg.Host, cfg.Port),
		Path:     cfg.Name,
		RawQuery: query.Encode(),
	}

	return dsn.String()
}
// NewPGX builds a pgx connection pool from cfg and verifies it with
// PingConnection before returning the wrapper.
//
// The pool configuration is parsed from cfg.DSN(). When cfg.Tracing is set, an
// otelpgx tracer is installed on the connection config. If the ping phase
// fails, the freshly created pool is closed before the error is returned so
// that a failed construction does not leave the pool behind.
//
// Every failure is wrapped with the "create db conn pool" message.
func NewPGX(ctx context.Context, cfg Config) (*PGX, error) { //nolint:gocritic
	pCfg, err := pgxpool.ParseConfig(cfg.DSN())
	if err != nil {
		return nil, errors.Wrap(err, "create db conn pool")
	}

	if cfg.Tracing {
		// NOTE(review): WithIncludeQueryParameters presumably attaches the
		// bound query arguments to spans; confirm that no sensitive values
		// can reach the tracing backend this way.
		pCfg.ConnConfig.Tracer = otelpgx.NewTracer(
			otelpgx.WithIncludeQueryParameters(),
			otelpgx.WithTrimSQLInSpanName(),
		)
	}

	pool, err := pgxpool.NewWithConfig(ctx, pCfg)
	if err != nil {
		return nil, errors.Wrap(err, "create db conn pool")
	}

	if err = PingConnection(ctx, &cfg, func(pingCtx context.Context) error {
		return pool.Ping(pingCtx)
	}); err != nil {
		// The caller never receives the pool on this path, so release it
		// here; nobody else could close it.
		pool.Close()

		return nil, errors.Wrap(err, "create db conn pool")
	}

	return &PGX{
		pool: pool,
	}, nil
}
// PingConnection calls pinger until it succeeds, the attempts are exhausted,
// or ctx is done.
//
// cfg.ConnRetry is the number of attempts; values below one are treated as a
// single attempt so that the connection is always checked at least once.
// cfg.ConnTimeout is the interval waited between attempts; when it is not
// positive the next attempt follows immediately (a ticker cannot be created
// with a non-positive interval).
//
// It returns nil on the first successful ping. A ping error that is
// context.Canceled or context.DeadlineExceeded aborts the loop at once. Any
// other ping error is logged and retried; after the final attempt the last
// error is returned without waiting for another interval. All returned errors
// are wrapped with the "ping database connection" message.
func PingConnection(ctx context.Context, cfg *Config, pinger func(ctx context.Context) error) error {
	attempts := cfg.ConnRetry
	if attempts < 1 {
		attempts = 1
	}

	// time.NewTicker panics on a non-positive interval, so the ticker is only
	// created when there is an actual pause to wait for.
	var ticker *time.Ticker
	if cfg.ConnTimeout > 0 {
		ticker = time.NewTicker(cfg.ConnTimeout)
		defer ticker.Stop()
	}

	var err error
	for i := 0; i < attempts; i++ {
		switch err = pinger(ctx); {
		case err == nil:
			return nil
		case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
			return errors.Wrap(err, "ping database connection")
		default:
			zap.L().Error("failed to ping database connection", zap.Error(err))
		}

		// No further attempt follows, so there is nothing to wait for.
		if i == attempts-1 {
			break
		}

		if ticker == nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return errors.Wrap(ctxErr, "ping database connection")
			}
			continue
		}

		select {
		case <-ctx.Done():
			return errors.Wrap(ctx.Err(), "ping database connection")
		case <-ticker.C:
		}
	}

	return errors.Wrap(err, "ping database connection")
}
// Close closes the underlying connection pool. The returned error is always
// nil.
func (d *PGX) Close() error {
	d.pool.Close()

	return nil
}
// DB exposes the pool through the database/sql API. Every call asks
// stdlib.OpenDBFromPool for a handle on top of the same pool.
func (d *PGX) DB() *sql.DB {
	handle := stdlib.OpenDBFromPool(d.pool)

	return handle
}
// CollectMetrics copies the current pool statistics into the package's
// connection pool gauges.
//
// If ctx is already done, its error is returned and no gauge is touched;
// otherwise the gauges are updated and nil is returned.
//
// The snapshot is taken synchronously. The earlier implementation took it in a
// goroutine that reported over a channel the caller closed on return, so a
// cancelled ctx could make the goroutine send on a closed channel and panic.
func (d *PGX) CollectMetrics(ctx context.Context) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	s := d.pool.Stat()

	totalConns.Set(float64(s.TotalConns()))
	acquireCount.Set(float64(s.AcquireCount()))
	acquiredCount.Set(float64(s.AcquiredConns()))
	canceledAcquireCount.Set(float64(s.CanceledAcquireCount()))
	constructingConns.Set(float64(s.ConstructingConns()))
	emptyAcquireCount.Set(float64(s.EmptyAcquireCount()))
	idleConns.Set(float64(s.IdleConns()))
	maxConns.Set(float64(s.MaxConns()))
	newConnsCount.Set(float64(s.NewConnsCount()))
	maxLifetimeDestroyCount.Set(float64(s.MaxLifetimeDestroyCount()))
	maxIdleDestroyCount.Set(float64(s.MaxIdleDestroyCount()))

	return nil
}
// Query runs sql on the pool and records how long the pool call took under the
// name of the calling function.
//
// On failure the error is logged and returned with nil rows. On success the
// rows are returned wrapped in RowsWithMetrics.
func (d *PGX) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
	// funcName must be called directly from this method: it resolves the
	// caller by a fixed stack depth.
	caller := funcName()

	startedAt := time.Now()
	rows, err := d.pool.Query(ctx, sql, args...)
	elapsed := time.Since(startedAt)

	queryTimeObserver(caller).Observe(elapsed.Seconds())

	if err == nil {
		return &RowsWithMetrics{Rows: rows, callerFnName: caller}, nil
	}

	zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err)))

	return nil, err
}
// QueryRow runs sql on the pool, records how long the pool call took under the
// name of the calling function, and returns the row wrapped in RowWithMetrics.
func (d *PGX) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
	// funcName must be called directly from this method: it resolves the
	// caller by a fixed stack depth.
	caller := funcName()

	startedAt := time.Now()
	row := d.pool.QueryRow(ctx, sql, args...)
	elapsed := time.Since(startedAt)

	queryTimeObserver(caller).Observe(elapsed.Seconds())

	wrapped := &RowWithMetrics{Row: row, callerFnName: caller}

	return wrapped
}
// Begin starts a transaction on the pool and returns it wrapped in
// TxWithMetrics.
//
// The wrapper remembers the calling function's name and the time taken just
// before the transaction was requested. On failure the error is logged and
// returned with a nil transaction.
func (d *PGX) Begin(ctx context.Context) (pgx.Tx, error) {
	// funcName must be called directly from this method: it resolves the
	// caller by a fixed stack depth.
	caller := funcName()
	startedAt := time.Now()

	tx, err := d.pool.Begin(ctx)
	if err == nil {
		return &TxWithMetrics{Tx: tx, callerFnName: caller, start: startedAt}, nil
	}

	zap.L().Error("failure to begin a transaction", zap.Error(errors.WithStack(err)))

	return nil, err
}
const (
	// skipN is the number of stack frames skipped by funcName: funcName
	// itself and the function that called it.
	skipN = 2
	// delimiter separates the segments of a fully qualified function name.
	delimiter = "."
)

// funcName returns the last dot-separated segment of the name of the function
// that called funcName's caller, for example "Load" for "pkg.(*Repo).Load".
//
// It must be called directly from the function whose caller is wanted, because
// the frame is located by a fixed depth. When the frame or its function
// information cannot be resolved, "unknown" is returned instead of
// dereferencing a nil *runtime.Func.
func funcName() string {
	const unknown = "unknown"

	pc, _, _, ok := runtime.Caller(skipN)
	if !ok {
		return unknown
	}

	fn := runtime.FuncForPC(pc)
	if fn == nil {
		return unknown
	}

	name := fn.Name()
	if i := strings.LastIndex(name, delimiter); i >= 0 {
		return name[i+len(delimiter):]
	}

	return name
}
package db
import (
"github.com/jackc/pgx/v5"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// RowsWithMetrics decorates pgx.Rows with error logging and carries the name
// of the function that issued the query.
type RowsWithMetrics struct {
	pgx.Rows
	callerFnName string
}

// RowWithMetrics decorates pgx.Row with error logging and carries the name of
// the function that issued the query.
type RowWithMetrics struct {
	pgx.Row
	callerFnName string
}
// Next forwards to the wrapped rows and returns their result unchanged.
func (r *RowsWithMetrics) Next() bool {
	return r.Rows.Next()
}
// Scan forwards to the wrapped rows. A scan error is logged and then returned
// as is.
func (r *RowsWithMetrics) Scan(args ...any) error {
	if err := r.Rows.Scan(args...); err != nil {
		zap.L().Error("failure to scan into the provided args", zap.Error(errors.WithStack(err)))

		return err
	}

	return nil
}
// Values forwards to the wrapped rows. An error is logged; the values and the
// error are returned exactly as the wrapped rows produced them.
func (r *RowsWithMetrics) Values() ([]interface{}, error) {
	values, err := r.Rows.Values()
	if err == nil {
		return values, nil
	}

	zap.L().Error("failure to get row values", zap.Error(errors.WithStack(err)))

	return values, err
}
// Err forwards to the wrapped rows. A non-nil error is logged and then
// returned as is.
func (r *RowsWithMetrics) Err() error {
	if err := r.Rows.Err(); err != nil {
		zap.L().Error("failure to iterate over returned rows", zap.Error(errors.WithStack(err)))

		return err
	}

	return nil
}
// Scan forwards to the wrapped row and returns its error unchanged.
//
// Errors are logged, except pgx.ErrNoRows: an empty result is an ordinary
// outcome of a single-row query that callers are expected to handle, so it is
// not reported as a failure. This mirrors how TxWithMetrics.Rollback leaves
// the expected pgx.ErrTxClosed out of the log.
func (r *RowWithMetrics) Scan(args ...any) error {
	err := r.Row.Scan(args...)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		zap.L().Error("failure to scan into the provided args", zap.Error(errors.WithStack(err)))
	}

	return err
}
package db
import (
"context"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// TxWithMetrics decorates pgx.Tx with error logging and records the
// transaction duration when it is committed or rolled back.
type TxWithMetrics struct {
	pgx.Tx

	// callerFnName is the value used for the "method" label of the
	// query-time histogram.
	callerFnName string
	// start is the reference point for the duration observed on Commit and
	// Rollback.
	start time.Time
}
// Commit commits the wrapped transaction, logs a commit error if there is one,
// and observes the time elapsed since the transaction's start under the stored
// caller name. The commit error is returned as is.
func (t *TxWithMetrics) Commit(ctx context.Context) error {
	commitErr := t.Tx.Commit(ctx)
	elapsed := time.Since(t.start)

	if commitErr != nil {
		zap.L().Error("failure to commit a transaction", zap.Error(errors.WithStack(commitErr)))
	}

	queryTimeObserver(t.callerFnName).Observe(elapsed.Seconds())

	return commitErr
}
// Rollback rolls back the wrapped transaction.
//
// pgx.ErrTxClosed means the transaction was already finished (the usual result
// of a deferred Rollback after Commit). In that case nil is returned and
// nothing is observed, because the call that finished the transaction has
// already recorded its duration. Otherwise the elapsed time since the
// transaction's start is observed under the stored caller name, and a genuine
// rollback failure is logged and returned instead of being swallowed.
func (t *TxWithMetrics) Rollback(ctx context.Context) error {
	err := t.Tx.Rollback(ctx)
	if errors.Is(err, pgx.ErrTxClosed) {
		return nil
	}

	queryTimeObserver(t.callerFnName).Observe(time.Since(t.start).Seconds())

	if err != nil {
		zap.L().Error("failure to rollback a transaction", zap.Error(errors.WithStack(err)))

		return err
	}

	return nil
}
// Exec runs sql inside the wrapped transaction. On failure the error is logged
// and returned together with an empty command tag.
func (t *TxWithMetrics) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
	tag, err := t.Tx.Exec(ctx, sql, args...)
	if err == nil {
		return tag, nil
	}

	zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err)))

	return pgconn.CommandTag{}, err
}
// Query runs sql inside the wrapped transaction. On failure the error is
// logged and returned with nil rows; on success the rows are returned wrapped
// in RowsWithMetrics with the transaction's caller name.
func (t *TxWithMetrics) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
	rows, err := t.Tx.Query(ctx, sql, args...)
	if err == nil {
		return &RowsWithMetrics{Rows: rows, callerFnName: t.callerFnName}, nil
	}

	zap.L().Error("failure to execute query", zap.Error(errors.WithStack(err)))

	return nil, err
}
// QueryRow runs sql inside the wrapped transaction and returns the row wrapped
// in RowWithMetrics with the transaction's caller name.
func (t *TxWithMetrics) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
	row := t.Tx.QueryRow(ctx, sql, args...)

	return &RowWithMetrics{Row: row, callerFnName: t.callerFnName}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment