Skip to content

Instantly share code, notes, and snippets.

@jxsl13
Created November 1, 2022 15:07
Show Gist options
  • Save jxsl13/70c237f84adaa637fe0591a6a128afc3 to your computer and use it in GitHub Desktop.
Save jxsl13/70c237f84adaa637fe0591a6a128afc3 to your computer and use it in GitHub Desktop.
oauth2 resty wrapper (auto token refresh)
package client
import (
"sync"
"github.com/go-resty/resty/v2"
)
var (
// factoryOnce is the sync.Once shared by Init and newFactory; Init uses it
// to build the factory a single time.
factoryOnce sync.Once
// factory is the package-wide ClientFactory assigned in Init and handed
// out by newFactory.
factory *ClientFactory
)
// Config holds the settings from which a ClientFactory is built.
type Config struct {
// Insecure is copied into tls.Config.InsecureSkipVerify, both for the
// token client and for the clients the factory creates.
Insecure bool
// TokenUrl is the URL the token request is POSTed to.
TokenUrl string
// ClientId is sent as the "client_id" form field of the
// client_credentials grant.
ClientId string
// ClientSecret is sent as the "client_secret" form field of the
// client_credentials grant.
ClientSecret string
}
// initErr keeps the outcome of the one-time initialization so that every
// call to Init reports it, not only the call that actually performed it.
var initErr error

// Init builds the package-wide client factory from config exactly once and
// immediately fetches a token, which verifies that the initialization worked
// and that the configured credentials are accepted.
//
// Only the first call performs any work. Every call, including later ones,
// returns the error produced by that first token fetch (nil on success), so a
// failed initialization can no longer be mistaken for a successful one by a
// subsequent caller.
func Init(config Config) error {
	factoryOnce.Do(func() {
		factory = NewClientFactory(config)
		// Fetch a token right away to validate endpoint and credentials.
		_, initErr = factory.tokenProvider.UpdateToken()
	})
	return initErr
}
// newFactory returns the package-wide ClientFactory created by Init.
//
// It panics with "client factory not initialized" when Init has not been
// called before. sync.Once treats a panicking function as having returned, so
// the closure below only fires on the very first call; the explicit nil check
// makes every later call fail with the same clear message instead of handing
// out a nil factory that is dereferenced by the caller.
func newFactory() *ClientFactory {
	factoryOnce.Do(func() {
		panic("client factory not initialized")
	})
	if factory == nil {
		panic("client factory not initialized")
	}
	return factory
}
// New creates a resty client from the package-wide factory.
// The factory lookup panics if the package has not been initialized.
func New() *resty.Client {
	f := newFactory()
	return f.New()
}
package client
// NewUserCredentials builds the form fields of a "password" grant for the
// given username and password, using the fixed client id "public".
func NewUserCredentials(username, password string) map[string]string {
	form := make(map[string]string, 4)
	form["grant_type"] = "password"
	form["client_id"] = "public"
	form["username"] = username
	form["password"] = password
	return form
}
// NewClientCredentials builds the form fields of a "client_credentials" grant
// for the given client id and secret.
func NewClientCredentials(clientID, clientSecret string) map[string]string {
	form := make(map[string]string, 3)
	form["grant_type"] = "client_credentials"
	form["client_id"] = clientID
	form["client_secret"] = clientSecret
	return form
}
package client
// UnexpectedTokenResponse is the error returned when fetching a new JWT
// fails with an unexpected response from the token endpoint.
type UnexpectedTokenResponse struct {
	// Msg is a short description of what went wrong.
	Msg string
	// Err carries the detail text that is appended after Msg.
	Err string
}

// Error implements the error interface, rendering the value as "Msg: Err".
func (e UnexpectedTokenResponse) Error() string {
	text := e.Msg
	text += ": "
	text += e.Err
	return text
}
package client
import (
"crypto/tls"
"time"
"github.com/go-resty/resty/v2"
)
// NewClientFactory creates a ClientFactory whose token provider and TLS
// verification setting are derived from config.
func NewClientFactory(config Config) *ClientFactory {
	cf := new(ClientFactory)
	cf.tokenProvider = getTokenProvider(config)
	cf.insecure = config.Insecure
	return cf
}
// ClientFactory creates resty clients that all use the same TokenProvider.
type ClientFactory struct {
// tokenProvider is handed to the bearer-token middleware and the retry
// condition of every client created by New.
tokenProvider *TokenProvider
// insecure is copied into InsecureSkipVerify of each created client's
// TLS configuration.
insecure bool
}
// UpdateToken asks the factory's token provider to fetch a new token and
// returns the provider's result unchanged.
func (cf *ClientFactory) UpdateToken() (*JWT, error) {
	provider := cf.tokenProvider
	return provider.UpdateToken()
}
// Ping can be used to check whether a connection to the configured API can
// be established. It does so by requesting a new token and returns the error
// of that request, or nil when it succeeded.
func (cf *ClientFactory) Ping() error {
	if _, err := cf.tokenProvider.UpdateToken(); err != nil {
		return err
	}
	return nil
}
// getTokenProvider builds the TokenProvider for config: it uses the
// client_credentials grant against config.TokenUrl and applies a TLS
// configuration whose certificate verification follows config.Insecure.
func getTokenProvider(config Config) *TokenProvider {
	credentials := NewClientCredentials(config.ClientId, config.ClientSecret)
	provider := NewTokenProvider(config.TokenUrl, credentials)

	tlsConfig := &tls.Config{InsecureSkipVerify: config.Insecure}
	provider.WithTLSClientConfig(tlsConfig)

	return provider
}
// New creates a resty client that attaches the provider's bearer token to
// each request, retries up to 3 times using the unauthorized retry condition,
// waits between retries according to NewRetryAfter with a 10 second default,
// and uses a TLS configuration that follows the factory's insecure flag.
func (cf *ClientFactory) New() *resty.Client {
	client := resty.New()
	client.AddRetryCondition(NewUnauthorizedCondition(cf.tokenProvider))
	client.OnBeforeRequest(NewBearerTokenMiddleware(cf.tokenProvider))
	client.SetRetryCount(3)
	client.SetRetryAfter(NewRetryAfter(10 * time.Second))
	client.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: cf.insecure})
	return client
}
package client
import (
"net/http"
"strconv"
"time"
"github.com/go-resty/resty/v2"
)
// NewBearerTokenMiddleware returns a request middleware that obtains a token
// from t and, when one is available, sets its access token as the request's
// auth token. The error from the token lookup is returned as the middleware's
// result.
func NewBearerTokenMiddleware(t *TokenProvider) resty.RequestMiddleware {
	return func(_ *resty.Client, req *resty.Request) error {
		token, err := t.Token()
		if token == nil {
			return err
		}
		req.SetAuthToken(token.AccessToken)
		return err
	}
}
// NewUnauthorizedCondition returns a retry condition that
//   - retries whenever the request produced an error or no response,
//   - triggers a token update on t when the response status is 401,
//   - retries every error response except 403 Forbidden.
func NewUnauthorizedCondition(t *TokenProvider) resty.RetryConditionFunc {
	return func(r *resty.Response, err error) bool {
		if err != nil || r == nil {
			return true
		}

		status := r.StatusCode()
		if status == http.StatusUnauthorized {
			// The outcome of the refresh is intentionally not inspected here;
			// the retry decision below depends only on the response.
			_, _ = t.UpdateToken()
		}

		return r.IsError() && status != http.StatusForbidden
	}
}
// NewServiceUnavailableCondition returns a retry condition that asks for a
// retry only when the response status is 503 Service Unavailable.
//
// A nil response (for example when the request failed before any response was
// received) is not a 503, so the condition reports false for it instead of
// dereferencing the nil pointer.
func NewServiceUnavailableCondition() resty.RetryConditionFunc {
	return func(r *resty.Response, _ error) bool {
		if r == nil {
			return false
		}
		return r.StatusCode() == http.StatusServiceUnavailable
	}
}
// NewRetryAfter returns a resty retry-after function that derives the wait
// time from the response's Retry-After header and falls back to defaultSleep.
//
// The header is interpreted as
//   - a non-negative integer number of seconds, or
//   - an absolute time in time.RFC1123 layout that lies in the future.
//
// A missing header, a negative number, an unparsable value or a time that is
// not in the future all yield defaultSleep. The returned error is always nil.
//
// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
func NewRetryAfter(defaultSleep time.Duration) resty.RetryAfterFunc {
	return func(_ *resty.Client, resp *resty.Response) (time.Duration, error) {
		header := resp.Header().Get("Retry-After")
		if header == "" {
			return defaultSleep, nil
		}

		// "seconds only" format
		if seconds, err := strconv.Atoi(header); err == nil {
			if seconds >= 0 {
				return time.Duration(seconds) * time.Second, nil
			}
			return defaultSleep, nil
		}

		// absolute point in time
		now := time.Now()
		when, err := time.Parse(time.RFC1123, header)
		if err != nil || !when.After(now) {
			return defaultSleep, nil
		}
		return when.Sub(now), nil
	}
}
package client
import (
"crypto/tls"
"errors"
"sync"
"time"
"stoxdog/internal/logging"
"github.com/go-resty/resty/v2"
)
// JWT is the structure the token endpoint's JSON response is decoded into.
type JWT struct {
// AccessToken is the value sent as the bearer auth token on requests.
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
// ExpiresIn is treated as the token lifetime in seconds when the
// provider computes the local expiry time.
ExpiresIn int `json:"expires_in"`
RefreshExpiresIn int `json:"refresh_expires_in"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
// NotBeforePolicy is interpreted by the provider as a Unix timestamp in
// seconds.
NotBeforePolicy int64 `json:"not-before-policy"`
SessionState string `json:"session_state"`
Scope string `json:"scope"`
}
// TokenProvider fetches tokens from a token endpoint and caches the most
// recent one together with its computed expiry time.
type TokenProvider struct {
// restyClient is the HTTP client used for the token requests.
restyClient *resty.Client
// tokenUrl is the URL the token request is POSTed to.
tokenUrl string
// credentials are sent as form data with every token request.
credentials map[string]string
// currentJWT is the most recently fetched token, nil before the first
// successful fetch.
currentJWT *JWT
// tokenExpireTime is the local time after which currentJWT is refreshed;
// the zero value means no expiry is tracked.
tokenExpireTime time.Time
// mu is locked by the exported methods while they read or modify the
// provider.
mu sync.Mutex
logger logging.Logger
}
// NewTokenProvider creates a TokenProvider that requests tokens from tokenUrl
// using the given form credentials. It starts with a fresh resty client and a
// no-op logger; no token is fetched yet.
func NewTokenProvider(tokenUrl string, credentials map[string]string) *TokenProvider {
	provider := new(TokenProvider)
	provider.restyClient = resty.New()
	provider.tokenUrl = tokenUrl
	provider.credentials = credentials
	provider.logger = &logging.NoOpLogger{}
	return provider
}
// Token returns the cached token, fetching a new one first when none has been
// fetched yet or when the tracked expiry time has passed. A zero expiry time
// means the cached token is never considered expired.
//
// It returns an error when called on a nil provider or when fetching fails.
func (t *TokenProvider) Token() (*JWT, error) {
	if t == nil {
		return nil, errors.New("Token service not initialized")
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.currentJWT == nil {
		return t.updateToken()
	}

	expired := !t.tokenExpireTime.IsZero() && time.Now().After(t.tokenExpireTime)
	if !expired {
		return t.currentJWT, nil
	}

	t.logger.Debug("Token expired. Updating token.")
	return t.updateToken()
}
// UpdateToken unconditionally fetches a new token while holding the
// provider's lock and returns it. It returns an error when called on a nil
// provider or when fetching fails.
func (t *TokenProvider) UpdateToken() (*JWT, error) {
	if t != nil {
		t.mu.Lock()
		defer t.mu.Unlock()
		return t.updateToken()
	}
	return nil, errors.New("Token service not initialized")
}
// updateToken POSTs the credentials as form data to the token URL and, on a
// 2xx response, stores the decoded token as the current one.
//
// The exported methods of TokenProvider call it with t.mu held.
//
// After a successful fetch:
//   - when ExpiresIn is positive the expiry time is set to now plus
//     ExpiresIn minus 30 seconds, otherwise it is reset to the zero time;
//   - when NotBeforePolicy is positive the call sleeps until that Unix
//     timestamp before returning.
//
// A transport error is returned as is; a non-2xx status is returned as an
// UnexpectedTokenResponse carrying the response body.
func (t *TokenProvider) updateToken() (*JWT, error) {
	fresh := &JWT{}

	request := t.restyClient.R().
		SetFormData(t.credentials).
		SetHeader("Cache-Control", "no-cache").
		SetResult(fresh)

	resp, err := request.Post(t.tokenUrl)
	if err != nil {
		return nil, err
	}

	if status := resp.StatusCode(); status/100 != 2 {
		return nil, UnexpectedTokenResponse{
			Msg: "unexpected authentication error",
			Err: string(resp.Body()),
		}
	}

	t.currentJWT = fresh

	if fresh.ExpiresIn > 0 {
		lifetime := time.Duration(fresh.ExpiresIn-30) * time.Second
		t.tokenExpireTime = time.Now().Add(lifetime)
	} else {
		t.tokenExpireTime = time.Time{}
	}

	if fresh.NotBeforePolicy > 0 {
		notBefore := time.Unix(fresh.NotBeforePolicy, 0)
		time.Sleep(time.Until(notBefore))
	}

	return fresh, nil
}
// WithTLSClientConfig applies config as the TLS configuration of the
// provider's internal resty client, under the provider's lock, and returns
// the provider so calls can be chained.
func (t *TokenProvider) WithTLSClientConfig(config *tls.Config) *TokenProvider {
	t.mu.Lock()
	defer t.mu.Unlock()

	client := t.restyClient
	client.SetTLSClientConfig(config)

	return t
}
// WithLogger installs logger on both the provider and its internal resty
// client, under the provider's lock, and returns the provider so calls can be
// chained.
func (t *TokenProvider) WithLogger(logger logging.Logger) *TokenProvider {
	t.mu.Lock()
	defer t.mu.Unlock()

	t.restyClient.SetLogger(logger)
	t.logger = logger

	return t
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment