Skip to content

Instantly share code, notes, and snippets.

@tiegz
Last active June 9, 2021 16:06
Show Gist options
  • Save tiegz/153176554fba0e38bb57abf2c1a33bce to your computer and use it in GitHub Desktop.
Save tiegz/153176554fba0e38bb57abf2c1a33bce to your computer and use it in GitHub Desktop.
mutual_tls_server.go
package main
import (
"crypto/tls"
"crypto/x509"
"flag"
"io"
"io/ioutil"
"log"
"net/http"
"strings"
)
// Runs an HTTPS server on :53434 that does mutual TLS auth: every client must
// present a certificate that verifies against the PEM file given by -ingressCert.
//
// Usage:
//
//	go run mutual_tls_server.go -ingressCert client.crt -hostCert server.crt -hostKey server.key
//	curl -v \
//	  --cacert server.crt \
//	  --key client.key --cert client.crt \
//	  https://localhost:53434
//
// Requesting /502 returns 502 Bad Gateway; any other path returns 200 OK.
func main() {
	var (
		flagIngressCertPath = flag.String("ingressCert", "", "Filepath to ingress client certificate (for mutual SSL)")
		flagHostCertPath    = flag.String("hostCert", "", "Filepath to this host's server certificate (for mutual SSL)")
		flagHostKeyPath     = flag.String("hostKey", "", "Filepath to this host's server private key (for mutual SSL)")
	)
	flag.Parse()

	// All three files are required; fail fast with the usage text instead of a
	// confusing "open : no such file or directory" error further down.
	if *flagIngressCertPath == "" || *flagHostCertPath == "" || *flagHostKeyPath == "" {
		flag.Usage()
		log.Fatal("-ingressCert, -hostCert and -hostKey are all required")
	}

	// Load the client's cert into the pool used to verify incoming client certificates.
	clientCert, err := ioutil.ReadFile(*flagIngressCertPath)
	if err != nil {
		log.Fatalf("Couldn't open client cert %s: %s\n", *flagIngressCertPath, err)
	}
	certPool := x509.NewCertPool()
	// AppendCertsFromPEM reports false when nothing could be parsed; an empty
	// pool would make the server reject every client without explanation.
	if !certPool.AppendCertsFromPEM(clientCert) {
		log.Fatalf("No PEM certificates found in client cert %s\n", *flagIngressCertPath)
	}

	// Load the server's own cert.
	serverCert, err := tls.LoadX509KeyPair(*flagHostCertPath, *flagHostKeyPath)
	if err != nil {
		log.Fatalf("Server cert err: %s\n", err)
	}

	// NOTE(review): no ReadHeaderTimeout/ReadTimeout is set; acceptable for a
	// local test server, but consider adding one if this is ever exposed.
	server := &http.Server{
		Addr:      ":53434",
		TLSConfig: newTLSConfig(serverCert, certPool),
		Handler:   newMux(),
	}
	log.Println("Listening on port 53434")
	// The certificate is already in TLSConfig.Certificates, so no file paths
	// are passed here (passing them would re-read the key pair from disk).
	err = server.ListenAndServeTLS("", "")
	if err != nil {
		log.Fatalf("Error: %s\n", err)
	}
}

// newTLSConfig returns a server TLS config that presents serverCert and
// requires every client to present a certificate verifiable against clientCAs.
//
// The client-auth settings sit on the base config so that http.Server can
// still add its ALPN protocols to it. To additionally restrict clients by IP,
// add a GetConfigForClient hook: hello.Conn.RemoteAddr() is available there
// before the client certificate is verified.
func newTLSConfig(serverCert tls.Certificate, clientCAs *x509.CertPool) *tls.Config {
	return &tls.Config{
		Certificates: []tls.Certificate{serverCert},      // verify server
		ClientCAs:    clientCAs,                          // verify client
		ClientAuth:   tls.RequireAndVerifyClientCert,     // verify client
		MinVersion:   tls.VersionTLS12,                   // refuse legacy TLS 1.0/1.1
	}
}

// newMux routes every request path to handleRequest.
func newMux() *http.ServeMux {
	mux := http.NewServeMux()
	mux.HandleFunc("/", handleRequest)
	return mux
}

// handleRequest logs the request line, client address and header names, then
// replies 502 Bad Gateway for the path /502 (to simulate an upstream failure)
// and 200 OK for every other path.
func handleRequest(w http.ResponseWriter, r *http.Request) {
	headers := make([]string, 0, len(r.Header))
	for key := range r.Header {
		headers = append(headers, key)
	}
	// Server-side requests normally carry an empty r.URL.Scheme, so derive the
	// scheme from whether the request arrived over TLS.
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	log.Printf("%s %s ip=%s scheme=%s host=%s path=%s headers=%s\n\n", r.Method, r.RequestURI, r.RemoteAddr, scheme, r.Host, r.URL.Path, strings.Join(headers, ","))

	// Simulate a 502
	if r.URL.Path == "/502" {
		w.WriteHeader(http.StatusBadGateway)
		io.WriteString(w, "Mutual TLS Server: 502 Bad Gateway\n")
		return
	}
	// Status code matches the "200 OK" the body advertises.
	w.WriteHeader(http.StatusOK)
	io.WriteString(w, "Mutual TLS Server: 200 OK\n")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment