Created January 4, 2018 00:45
-
-
Save Quentin-M/b8a6aa1742afab260fec733dc141d788 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	crand "crypto/rand"
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"log"
	"math/big"
	"math/rand"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/spf13/viper"
)
func init() { | |
rand.Seed(time.Now().UTC().UnixNano()) | |
viper.SetConfigName("bearer-rproxy") | |
viper.AddConfigPath(".") | |
viper.AddConfigPath("/etc/") | |
viper.SetEnvPrefix("BRP") | |
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) | |
viper.AutomaticEnv() | |
viper.SetDefault("listen.address", "0.0.0.0:9010") | |
viper.SetDefault("listen.auto-tls", "true") | |
viper.SetDefault("listen.graceful-timeout", "5s") | |
viper.SetDefault("backend.request-timeout", "5s") | |
} | |
func main() { | |
// Build server. | |
s := &http.Server{ | |
Addr: viper.GetString("listen.address"), | |
Handler: handlersChain(), | |
TLSConfig: tlsConfig(), | |
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), | |
} | |
// Listen and Serve. | |
go func() { | |
log.Printf("listening on %s (TLS: %v)\n", s.Addr, s.TLSConfig != nil) | |
if s.TLSConfig != nil { | |
if err := s.ListenAndServeTLS("", ""); err != http.ErrServerClosed { | |
log.Fatalf("failed to listen/serve: %v", err) | |
} | |
} else { | |
if err := s.ListenAndServe(); err != http.ErrServerClosed { | |
log.Fatalf("failed to listen/serve: %v", err) | |
} | |
} | |
}() | |
// Wait for SIGTERM signal, to attempt a graceful shutdown. | |
waitSigterm() | |
ctx, cancel := context.WithTimeout(context.Background(), viper.GetDuration("listen.graceful-timeout")) | |
defer cancel() | |
log.Printf("stopping gracefully (%v)\n", viper.GetDuration("listen.graceful-timeout")) | |
if err := s.Shutdown(ctx); err != nil { | |
log.Printf("failed to stop gracefully: %v\n", err) | |
} | |
} | |
// selfSignedCertificate generates a fresh ECDSA P-256 key pair and a
// self-signed X.509 server certificate for it, valid for 10 years.
//
// The certificate lists no DNS names or IP addresses, so clients must skip
// hostname verification (or pin the certificate) to connect. The returned
// tls.Certificate has Leaf populated.
func selfSignedCertificate() (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to create pub/priv key pair: %v", err)
	}
	privDERBytes, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to marshal pub/priv key pair in DER: %v", err)
	}

	// Random 128-bit serial number.
	snLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	sn, err := crand.Int(crand.Reader, snLimit)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to generate serial number: %v", err)
	}

	// Backdate slightly to tolerate clock skew between server and clients.
	notBefore := time.Now().Add(-1 * time.Hour)
	certTmpl := x509.Certificate{
		SerialNumber: sn,
		// For a self-signed certificate CreateCertificate takes the issuer
		// from the parent's Subject, i.e. this same name.
		Subject: pkix.Name{
			Organization:       []string{"BitMEX"},
			OrganizationalUnit: []string{"Bearer RProxy"},
		},
		NotBefore:   notBefore,
		NotAfter:    notBefore.AddDate(10, 0, 0),
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		// ECDSA keys are only used to sign (ECDHE_ECDSA handshakes);
		// KeyEncipherment applies to RSA key transport and must not be set here.
		KeyUsage:              x509.KeyUsageDigitalSignature,
		BasicConstraintsValid: true,
	}
	// TODO: IP Addresses / DNS Names
	certDERBytes, err := x509.CreateCertificate(crand.Reader, &certTmpl, &certTmpl, &priv.PublicKey, priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to generate certificate: %v", err)
	}

	// Round-trip through PEM so tls.X509KeyPair validates that key and
	// certificate match.
	var pubPEM, privPEM bytes.Buffer
	if err := pem.Encode(&pubPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDERBytes}); err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to PEM-encode certificate: %v", err)
	}
	if err := pem.Encode(&privPEM, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privDERBytes}); err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to PEM-encode private key: %v", err)
	}
	cert, err := tls.X509KeyPair(pubPEM.Bytes(), privPEM.Bytes())
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to parse generated certificate: %v", err)
	}

	// Populate Leaf so the TLS stack can use the parsed form directly.
	leaf, err := x509.ParseCertificate(certDERBytes)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to parse generated certificate: %v", err)
	}
	cert.Leaf = leaf
	return cert, nil
}
func tlsConfig() *tls.Config { | |
if !viper.GetBool("listen.auto-tls") { | |
log.Print("auto-tls not specified, tls is disabled (plain-text bearer!)") | |
return nil | |
} | |
// Generate a self-signed certificate. | |
certificate, err := selfSignedCertificate() | |
if err != nil { | |
log.Fatalf("failed to create self-signed certificate: %v", err) | |
} | |
// Create and return an associated TLS configuration. | |
return &tls.Config{ | |
Certificates: []tls.Certificate{certificate}, | |
MinVersion: tls.VersionTLS12, | |
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, | |
PreferServerCipherSuites: true, | |
CipherSuites: []uint16{ | |
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, | |
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, | |
tls.TLS_RSA_WITH_AES_256_GCM_SHA384, | |
tls.TLS_RSA_WITH_AES_256_CBC_SHA, | |
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, // cURL 7.54, LibreSSL/2.0.20 | |
}, | |
} | |
} | |
func handlersChain() (h http.Handler) { | |
h = rproxyHandler() | |
h = http.TimeoutHandler(h, viper.GetDuration("backend.request-timeout"), "backend timed-out / failed") | |
h = bearerHandler(h) | |
return | |
} | |
func parseRoutes() (map[string]*url.URL, error) { | |
r := viper.GetStringSlice("backend.routes") | |
if len(r) == 0 { | |
return nil, errors.New("no routes specified, exiting") | |
} | |
routes := make(map[string]*url.URL) | |
for _, rd := range r { | |
rds := strings.Split(rd, "=") | |
if len(rds) != 2 { | |
return routes, fmt.Errorf("could not parse route %q: format is <route name>=<route URL>", rd) | |
} | |
rTargetURL, err := url.Parse(rds[1]) | |
if err != nil { | |
return routes, fmt.Errorf("could not parse route URL %q: %v", rd, err) | |
} | |
if strings.Contains(rds[0], "/") { | |
return routes, fmt.Errorf("could not parse route %q: route name cannot contain /", rd) | |
} | |
routes[rds[0]] = rTargetURL | |
log.Printf("loaded route: %v -> %v\n", rds[0], rds[1]) | |
} | |
return routes, nil | |
} | |
func rproxyHandler() http.Handler { | |
routes, err := parseRoutes() | |
if err != nil { | |
log.Fatal(err) | |
} | |
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | |
// Find target based on first section of request's path. | |
sPath := strings.Split(req.URL.Path, "/") | |
if len(sPath) <= 1 { | |
http.Error(rw, "no route target specified", http.StatusBadGateway) | |
return | |
} | |
target, ok := routes[sPath[1]] | |
if !ok { | |
http.Error(rw, "unknown route target", http.StatusBadGateway) | |
return | |
} | |
// Reverse proxy the request. | |
rp := &httputil.ReverseProxy{ | |
Director: func(req *http.Request) { | |
req.URL.Scheme = target.Scheme | |
req.URL.Host = target.Host | |
req.Host = req.URL.Host | |
req.URL.Path = singleJoiningSlash(target.Path, strings.TrimPrefix(req.URL.Path, "/"+sPath[1])) | |
targetQuery := target.RawQuery | |
if targetQuery == "" || req.URL.RawQuery == "" { | |
req.URL.RawQuery = targetQuery + req.URL.RawQuery | |
} else { | |
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery | |
} | |
}, | |
} | |
rp.ServeHTTP(rw, req) | |
}) | |
} | |
func bearerHandler(h http.Handler) http.Handler { | |
bearer := viper.GetString("listen.token") | |
if len(bearer) == 0 { | |
log.Print("no bearer token specified, bearer auth is disabled") | |
} | |
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | |
if strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") != bearer { | |
http.Error(rw, "bearer token is invalid", http.StatusUnauthorized) | |
return | |
} | |
req.Header.Del("Authorization") | |
h.ServeHTTP(rw, req) | |
}) | |
} | |
// singleJoiningSlash concatenates the path pieces a and b with exactly one "/"
// between them. It adds one if neither side has it and drops one if both do.
// An empty b returns a unchanged.
func singleJoiningSlash(a, b string) string {
	if b == "" {
		return a
	}
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if trailing || leading {
		return a + b
	}
	return a + "/" + b
}
// waitSigterm blocks until the process receives SIGTERM or SIGINT
// (os.Interrupt, e.g. Ctrl-C), so the caller can begin a graceful shutdown.
//
// Signal relaying is stopped once the first signal arrives. A second
// SIGTERM/SIGINT during a stuck shutdown then gets Go's default handling
// (process termination) instead of being swallowed by the buffered channel.
func waitSigterm() {
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGTERM, os.Interrupt)
	defer signal.Stop(stop)
	<-stop
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.