Created
April 16, 2016 07:00
-
-
Save skiz/41190df62ea94dfcb3762edaaaa634a5 to your computer and use it in GitHub Desktop.
CSRF protection for gin built on gin-gonic/contrib/sessions, with back-button support, security enhancements, and clearer code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package csrf provides session based CSRF protection for the gin framework. | |
// This library is a heavily modified version of https://github.com/utrack/gin-csrf | |
// The primary differences of this library are that tokens are persisted for the | |
// life of the session (back button support), the secret is not stored in any | |
// form on the client side, and much of the code has been refactored for clarity. | |
package csrf | |
import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"io"

	"github.com/dchest/uniuri"
	"github.com/gin-gonic/contrib/sessions"
	"github.com/gin-gonic/gin"
)
// Keys and defaults shared by the middleware and the token helpers.
const (
	// csrfFormKey is the form field / query parameter name that carries the
	// submitted token.
	csrfFormKey = "_csrf"

	// csrfSaltName and csrfTokenName are the session keys holding the
	// per-session salt and the token derived from it.
	csrfSaltName  = "csrfSalt"
	csrfTokenName = "csrfToken"

	// csrfSecretName is the gin context key under which Middleware publishes
	// the configured secret for GenerateToken.
	csrfSecretName = "csrfSecret"

	// defaultSecret is the placeholder used when Options.Secret is empty; as
	// its value suggests, callers are expected to supply their own secret.
	defaultSecret = "changeMeWithOptions"
)
var defaultIgnoreMethods = []string{"GET", "HEAD", "OPTIONS"} | |
var defaultErrorFunc = func(c *gin.Context) { | |
c.String(403, "CSRF token mismatch") | |
} | |
var defaultTokenGetter = func(c *gin.Context) string { | |
r := c.Request | |
if t := r.FormValue(csrfFormKey); len(t) > 0 { | |
return t | |
} else if t := r.URL.Query().Get(csrfFormKey); len(t) > 0 { | |
return t | |
} else if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 { | |
return t | |
} else if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 { | |
return t | |
} | |
return "" | |
} | |
// Options to configure the CSRF middleware. | |
type Options struct { | |
Secret string | |
IgnoreMethods []string | |
ErrorFunc gin.HandlerFunc | |
TokenGetter func(c *gin.Context) string | |
} | |
// hash derives a CSRF token from the server secret and a per-session salt.
// The result is the padded, URL-safe base64 encoding of SHA-1 over the bytes
// "<salt>-<secret>".
//
// NOTE(review): SHA-1 over a plain concatenation (not an HMAC) is retained
// because changing the algorithm would invalidate tokens already stored in
// existing sessions.
func hash(secret, salt string) string {
	digest := sha1.New()
	for _, part := range [...]string{salt, "-", secret} {
		// hash.Hash.Write never returns an error, so the result is ignored.
		_, _ = io.WriteString(digest, part)
	}
	return base64.URLEncoding.EncodeToString(digest.Sum(nil))
}
// inArray reports whether value occurs in arr, using exact (case-sensitive)
// string equality. A nil or empty slice contains nothing.
func inArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] == value {
			return true
		}
	}
	return false
}
// Middleware validates a session based CSRF token. | |
func Middleware(options Options) gin.HandlerFunc { | |
if options.Secret == "" { | |
options.Secret = defaultSecret | |
} | |
if options.IgnoreMethods == nil { | |
options.IgnoreMethods = defaultIgnoreMethods | |
} | |
if options.ErrorFunc == nil { | |
options.ErrorFunc = defaultErrorFunc | |
} | |
if options.TokenGetter == nil { | |
options.TokenGetter = defaultTokenGetter | |
} | |
return func(c *gin.Context) { | |
session := sessions.Default(c) | |
c.Set(csrfSecretName, options.Secret) | |
if inArray(options.IgnoreMethods, c.Request.Method) { | |
c.Next() | |
return | |
} | |
salt, ok := session.Get(csrfSaltName).(string) | |
if !ok || len(salt) == 0 { | |
options.ErrorFunc(c) | |
c.Abort() | |
return | |
} | |
token := options.TokenGetter(c) | |
if hash(options.Secret, salt) == token { | |
c.Next() | |
return | |
} | |
options.ErrorFunc(c) | |
c.Abort() | |
return | |
} | |
} | |
// GetToken returns the CSRF token for the session, or generates a new one. | |
func GetToken(c *gin.Context) string { | |
session := sessions.Default(c) | |
token, ok := session.Get(csrfTokenName).(string) | |
if ok && len(token) > 0 { | |
return token | |
} | |
return GenerateToken(c) | |
} | |
// GenerateToken sets and returns a new CSRF token for the current session. | |
func GenerateToken(c *gin.Context) string { | |
session := sessions.Default(c) | |
salt := uniuri.New() | |
token := hash(c.MustGet(csrfSecretName).(string), salt) | |
session.Set(csrfTokenName, token) | |
session.Set(csrfSaltName, salt) | |
session.Save() | |
return token | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment