Created October 23, 2021 02:48
-
-
Save PedroMartinSteenstrup/ee9cbe6b54459e720ab334743f6d7d93 to your computer and use it in GitHub Desktop.
Basic implementation of Snowflake pass-through authentication in Superset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --------------------------------------
# Snowflake OAuth
# --------------------------------------
import os

# Snowflake account identifier, i.e. the "<account>" part of
# <account>.snowflakecomputing.com. An unset or empty variable becomes None,
# in which case none of the settings below are defined.
SNOWFLAKE_ACCOUNT = os.getenv('SNOWFLAKE_ACCOUNT') or None
if SNOWFLAKE_ACCOUNT:
    # Client credentials of the Snowflake OAuth security integration
    # (obtainable with SYSTEM$SHOW_OAUTH_CLIENT_SECRETS).
    SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_ID = os.getenv('SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_ID')
    SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_SECRET = os.getenv('SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_SECRET')
    # Endpoint the browser is sent to so the user can grant consent.
    SNOWFLAKE_OAUTH_AUTHORIZATION_ENDPOINT = f'https://{SNOWFLAKE_ACCOUNT}.snowflakecomputing.com/oauth/authorize'
    # Endpoint used server-side to exchange codes / refresh tokens.
    SNOWFLAKE_OAUTH_TOKEN_ENDPOINT = f'https://{SNOWFLAKE_ACCOUNT}.snowflakecomputing.com/oauth/token-request'
# ------ Models | |
import datetime
import json
import logging
import os
import sys
from typing import Optional

from flask import Blueprint, g, redirect, render_template
from flask_appbuilder import Model, expose
from flask_appbuilder.api import BaseApi, protect, safe

from superset.typing import FlaskResponse
def now():
    """Return the current UTC wall-clock time as text.

    The format is ``'YYYY-MM-DD HH:MM:SS UTC'``, for example
    ``'2021-10-23 02:48:00 UTC'``.
    """
    current_utc = datetime.datetime.now(datetime.timezone.utc)
    return current_utc.strftime('%Y-%m-%d %H:%M:%S UTC')
from sqlalchemy import Column, DateTime, ForeignKey, Integer, Text


class AccessToken(Model):
    """Snowflake OAuth tokens stored for a Superset user.

    ``user_id`` is unique, so each user has at most one row.
    """

    __tablename__ = "access_token"
    # NOTE(review): nothing in this file reads ``type``; confirm it is unused
    # elsewhere before removing it.
    type = "table"

    id = Column(Integer, primary_key=True)
    access_token = Column(Text, unique=True, nullable=False)
    refresh_token = Column(Text, unique=True, nullable=False)
    # Naive UTC datetimes, as computed by the API with datetime.utcnow().
    expires_at = Column(DateTime, nullable=False)
    refresh_expires_at = Column(DateTime, nullable=False)
    # Defaults are datetime callables, not the text returned by now(): DateTime
    # columns expect datetime objects (SQLite's dialect rejects strings).
    # Passing the callable evaluates it per INSERT/UPDATE.
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
    updated_at = Column(
        DateTime,
        nullable=True,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )
    user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=False, unique=True)

    def __repr__(self):
        return repr(self.id)
# ---- DAO class | |
from superset.dao.base import BaseDAO | |
from superset.extensions import db | |
class AccessTokenDAO(BaseDAO):
    """Data-access helpers for ``AccessToken`` rows."""

    model_cls = AccessToken

    @classmethod
    def find_by_user_id(cls, user_id: int) -> Optional[AccessToken]:
        """Return the token row belonging to ``user_id``, or None if there is none.

        ``user_id`` is unique on the model, so at most one row can match.
        """
        query = db.session.query(cls.model_cls).filter_by(user_id=user_id)
        return query.one_or_none()
# --- REST API | |
class ExternalAuthorizationAPI(BaseApi):
    """Superset REST API that completes the Snowflake OAuth flow.

    Exposed under ``/api/v1/external_auth``. Snowflake sends the browser back
    to ``/redirect`` with an authorization code. That code is exchanged for an
    access token and a refresh token, which are stored in ``AccessToken`` for
    the logged-in user.
    """

    resource_name = 'external_auth'
    class_permission_name = "external_auth_api"
    # NOTE(review): Flask-AppBuilder looks this mapping up by *method* name, but
    # the endpoint method below is named ``redirect``, so the "get_redirect" key
    # probably never matches. Changing the key would change which permission
    # users need, so confirm against deployed roles before fixing it.
    method_permission_name = {
        "get_redirect": "access",
    }

    # Timeout (seconds) passed to requests.post for the token endpoint. Without
    # one, a stalled connection would block the web worker indefinitely.
    TOKEN_REQUEST_TIMEOUT = 30

    @staticmethod
    def _redirect_uri():
        """Return the OAuth callback URL derived from the SUPERSET_DOMAIN env var.

        It must match the OAUTH_REDIRECT_URI configured on the Snowflake
        security integration.
        """
        return f'{os.getenv("SUPERSET_DOMAIN")}/api/v1/external_auth/redirect'

    @staticmethod
    def _call_token_request_api(data):
        """POST ``data`` to Snowflake's OAuth token endpoint and return the JSON body.

        The request authenticates with HTTP Basic using the security
        integration's client ID and secret from the app config.

        :param data: form fields for the token request (``grant_type`` etc.).
        :returns: the decoded JSON response.
        :raises requests.HTTPError: if Snowflake answers with a 4xx/5xx status.
        """
        import base64
        from flask import current_app
        import requests

        config = current_app.config
        credentials = '{}:{}'.format(
            config["SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_ID"],
            config["SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_SECRET"],
        )
        base_encoding = base64.b64encode(credentials.encode()).decode()
        headers = {
            'Authorization': 'Basic {}'.format(base_encoding),
            'Content-type': 'application/x-www-form-urlencoded',
        }
        response_token = requests.post(
            config['SNOWFLAKE_OAUTH_TOKEN_ENDPOINT'],
            data=data,
            headers=headers,
            timeout=ExternalAuthorizationAPI.TOKEN_REQUEST_TIMEOUT,
        )
        response_token.raise_for_status()
        return response_token.json()

    def refresh_token(self, access_token: AccessToken):
        """Use the stored refresh token to obtain a new Snowflake access token.

        Persists the new access token and its expiry on ``access_token`` through
        ``AccessTokenDAO.update``, and returns whatever that call returns
        (presumably the updated model — confirm against BaseDAO).

        :raises requests.HTTPError: if Snowflake rejects the refresh request.
        """
        response = self._call_token_request_api({
            'grant_type': 'refresh_token',
            'refresh_token': access_token.refresh_token,
        })
        utc_now = datetime.datetime.utcnow()
        expiry_date = utc_now + datetime.timedelta(seconds=response['expires_in'])
        try:
            return AccessTokenDAO.update(access_token, {
                'access_token': response['access_token'],
                'expires_at': expiry_date,
                'updated_at': utc_now,
            })
        except Exception:
            # logging.exception already records the active traceback.
            logging.exception('Could not refresh token for user %s', access_token.user_id)
            raise

    def _handle_sf_authorization_code(self, code):
        """Exchange an OAuth authorization ``code`` for tokens.

        :returns: Snowflake's JSON token response.
        """
        return self._call_token_request_api({
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': self._redirect_uri(),
        })

    @expose('/redirect', methods=["GET"])
    @protect(allow_browser_login=True)
    @safe
    def redirect(self) -> FlaskResponse:
        """OAuth callback: store the Snowflake tokens for the logged-in user.

        Behavior depends on the request:

        - Anonymous users are redirected to the login page.
        - A missing ``code`` query argument yields a 400 JSON error.
        - If the Snowflake user differs from the Superset user, or the tokens
          cannot be saved, a message is flashed and the user is sent back to
          ``/snowflake_oauth``.
        - On success, the user lands on the welcome page.
        """
        # ``redirect`` below is flask.redirect (module global): method bodies do
        # not see class attributes, so it does not clash with this method's name.
        from superset import appbuilder
        from flask import request, flash

        if not g.user or not g.user.get_id():
            return redirect(appbuilder.get_url_for_login)
        code = request.args.get('code')
        if not code:
            return json.dumps({'error': 'empty authorization code'}), 400

        response = self._handle_sf_authorization_code(code)
        # Refuse to bind a Snowflake identity to a different Superset user.
        if response['username'] != g.user.username:
            flash('Error! User authenticated to Snowflake is not the same as currently logged in', 'danger')
            return redirect('/snowflake_oauth')

        utc_now = datetime.datetime.utcnow()
        expiry_date = utc_now + datetime.timedelta(seconds=response['expires_in'])
        refresh_expiry_date = utc_now + datetime.timedelta(seconds=response['refresh_token_expires_in'])
        ack = self.__create_token(access_token=response['access_token'],
                                  refresh_token=response['refresh_token'],
                                  expires_at=expiry_date,
                                  refresh_expires_at=refresh_expiry_date,
                                  user=g.user)
        if ack:
            flash('New OAuth token has been added', 'info')
            return redirect("/superset/welcome/")
        flash('Something went wrong! Couldn\'t get an OAuth token', 'danger')
        return redirect('/snowflake_oauth')

    @staticmethod
    def __create_token(
            access_token: str,
            expires_at: datetime.datetime,
            refresh_expires_at: datetime.datetime,
            refresh_token: str,
            user):
        """Insert or update the single ``AccessToken`` row for ``user``.

        :returns: the ``AccessToken`` instance on success, or ``False`` if any
            DAO call raised (the error is logged, not propagated).
        """
        token_values = {
            'access_token': access_token,
            'refresh_token': refresh_token,
            'expires_at': expires_at,
            'refresh_expires_at': refresh_expires_at,
        }
        try:
            access_token_instance = AccessTokenDAO.find_by_user_id(user.get_id())
            if access_token_instance:
                # A real datetime: the column is DateTime, and now() returns text.
                AccessTokenDAO.update(access_token_instance, {
                    **token_values,
                    'updated_at': datetime.datetime.utcnow(),
                })
                logging.info('Update token for user %s', user.username)
            else:
                access_token_instance = AccessTokenDAO.create({
                    **token_values,
                    'user_id': user.get_id(),
                })
                logging.info('New token added for user %s', user.username)
            return access_token_instance
        except Exception:
            logging.exception('Could not add/update token for user %s', user.username)
            return False
# ---------------------------- | |
# Custom app initializer | |
# ---------------------------- | |
from superset.app import SupersetAppInitializer | |
class MySupersetAppInitializer(SupersetAppInitializer):
    """Superset app initializer that also wires in the Snowflake OAuth pieces."""

    def init_views(self) -> None:
        """Register the external-auth API and a "Snowflake OAuth" menu link.

        Superset's standard view initialisation runs afterwards.
        """
        from superset.extensions import appbuilder
        from flask_babel import gettext as __

        appbuilder.add_api(ExternalAuthorizationAPI)
        # Menu entry under "Data" pointing at the snowflake_oauth blueprint page.
        appbuilder.add_link(
            "Snowflake OAuth",
            label="Snowflake OAuth",
            href="/snowflake_oauth",
            icon="fa-folder-open-o",
            category="Data",
            category_label=__("Data"),
        )
        super().init_views()


# Superset config hook: presumably used by Superset's app factory in place of
# the default initializer — confirm against the Superset version in use.
APP_INITIALIZER = MySupersetAppInitializer
# -------------------------- | |
# New flask blueprint | |
# --------------------------- | |
snowflake_oauth_bp = Blueprint(
    'snowflake_oauth',
    __name__,
    template_folder='templates',
)


@snowflake_oauth_bp.route('/snowflake_oauth')
def snowflake_oauth():
    """Render the page whose "Connect" button starts the Snowflake OAuth flow.

    Anonymous users are redirected to the login page. If the Snowflake OAuth
    settings are missing, the page is rendered without the connect panel and a
    warning is flashed, instead of failing with a KeyError. The settings count
    as missing when SNOWFLAKE_ACCOUNT is unset (so the client ID and
    authorization endpoint were never defined) or when either of those two
    values is empty.
    """
    from superset import appbuilder
    from urllib.parse import urlencode
    from flask import current_app, flash
    from flask_login import current_user

    if not current_user or not current_user.get_id():
        return redirect(appbuilder.get_url_for_login)

    config = current_app.config
    client_id = config.get('SNOWFLAKE_SECURITY_INTEGRATION_CLIENT_ID')
    authorization_endpoint = config.get('SNOWFLAKE_OAUTH_AUTHORIZATION_ENDPOINT')

    snowflake_auth = None
    if SNOWFLAKE_ACCOUNT and client_id and authorization_endpoint:
        # NOTE(review): no OAuth ``state`` parameter is sent or checked by the
        # callback; the username comparison there is the only binding check.
        query = urlencode({
            'client_id': client_id,
            # Must match the redirect URI used in the token exchange and the
            # OAUTH_REDIRECT_URI of the Snowflake security integration.
            'redirect_uri': f'{os.getenv("SUPERSET_DOMAIN")}/api/v1/external_auth/redirect',
            'response_type': 'code',
            'scope': 'refresh_token',
        })
        snowflake_auth = {
            'account': SNOWFLAKE_ACCOUNT,
            'auth_link': f"{authorization_endpoint}?{query}",
        }
    else:
        flash('Snowflake OAuth is not configured', 'warning')

    # The template only shows the connect panel when snowflake_auth is truthy.
    return render_template('snowflake_oauth.html',
                           appbuilder=appbuilder,
                           title='Snowflake OAuth',
                           snowflake_auth=snowflake_auth)


BLUEPRINTS = [snowflake_oauth_bp]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{% extends "superset/basic.html" %}
{# Rendered by the snowflake_oauth view. ``snowflake_auth`` is either a dict
   with ``account`` and ``auth_link`` (optionally ``validity_hours``), or a
   falsy value, in which case only flashed messages are shown. #}
{% block title %}
{{ title }}
{% endblock %}
{% block body %}
<div id="app" class="container">
  <div class="container">
    <div class="row">
      {% include "superset/flash_wrapper.html" %}
      {% if snowflake_auth %}
      <div class="container">
        <div class="mainbox col-md-6 col-md-offset-3 col-sm-8 col-sm-offset-2">
          <div class="panel panel-primary">
            <div class="panel-heading">
              <div class="panel-title">Snowflake OAuth</div>
            </div>
            <div class="panel-body" style="padding-top:30px">
              <div class="help-block">Click to log in to Snowflake account <em>{{ snowflake_auth.account }}</em>
              </div>
              {# Default of 8 hours mirrors OAUTH_REFRESH_TOKEN_VALIDITY=28800 on the security integration. #}
              <div class="help-block">This connection will be valid for {{ snowflake_auth.validity_hours | default(8) }} hours.</div>
              <br>
              <a href="{{ snowflake_auth.auth_link }}" class="btn btn-primary">Connect</a>
            </div>
          </div>
        </div>
      </div>
      {% endif %}
    </div>
  </div>
</div>
{% endblock %}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- -----------------------------------------------------
-- OAuth security integration for the production Superset
-- -----------------------------------------------------
-- Run as the ACCOUNTADMIN role.
CREATE SECURITY INTEGRATION superset_prod
    TYPE = OAUTH
    ENABLED = TRUE
    OAUTH_CLIENT = CUSTOM
    OAUTH_CLIENT_TYPE = 'CONFIDENTIAL'
    OAUTH_ISSUE_REFRESH_TOKENS = TRUE
    OAUTH_REFRESH_TOKEN_VALIDITY = 28800  -- seconds, i.e. 8 hours
    -- Must equal Superset's <SUPERSET_DOMAIN>/api/v1/external_auth/redirect.
    OAUTH_REDIRECT_URI = 'https://superset.yourdomain.com/api/v1/external_auth/redirect'
    OAUTH_ALLOW_NON_TLS_REDIRECT_URI = FALSE
    COMMENT = 'OAUTH with Snowflake from production superset';

-- Fetch the client ID and secret. Snowflake upper-cases unquoted identifiers,
-- hence 'SUPERSET_PROD'.
SELECT SYSTEM$SHOW_OAUTH_CLIENT_SECRETS('SUPERSET_PROD');
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.