Created
September 22, 2023 19:08
-
-
Save peterroelants/7cbb0deb2ad632e000a42b47052c2dff to your computer and use it in GitHub Desktop.
FastAPI example web app for OAuth2 Authorization Code Flow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
FastAPI example web app for OAuth2 Authorization Code Flow. | |
Run with: | |
``` | |
uvicorn main:app --reload --log-config=log_conf.yaml | |
``` | |
""" | |
import logging
import os
import secrets
import urllib
import urllib.parse
from typing import Optional

import aiohttp
import jinja2
import jwt
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
logger = logging.getLogger(__name__)

# Configuration. Each value can be supplied through an environment variable so
# the app can be configured without editing the source; when the variable is
# unset the value falls back to the hard-coded default below.
# .well-known URL of the authorization server (OpenID Connect discovery document).
AUTHORIZATION_SERVER_WELL_KNOWN_URL = os.environ.get("OIDC_WELL_KNOWN_URL", "")
# Client id registered for this app at the authorization server.
CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", "")

# Jinja2 source of the home page; rendered with ``user_name`` (None or empty
# for a guest, in which case the login link is shown).
HOME_TEMPLATE = r"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<title>Authorization Code Flow FastAPI Example Web app</title>
</head>
<body>
<h1>Authorization Code Flow FastAPI Example Web app</h1>
{% if user_name %}
<h2>Welcome {{user_name}}!</h2>
<p><a href="/logout">Logout</a></p>
{% else %}
<h2>Welcome Guest</h2>
<p><a href="/login">Login</a></p>
{% endif %}
</body>
</html>
""".strip()
# ASGI application. Routes are registered on it with ``@app.get`` below; the
# session middleware that backs ``request.session`` is attached at the end of
# the module.
app = FastAPI()
@app.get("/", response_class=HTMLResponse) | |
async def home(request: Request): | |
"""Home page displays user name if logged in, otherwise displays login link.""" | |
logger.debug(f"Home: {request.url=}") | |
user_info = request.session.get("user", {}) | |
user_name = user_info.get("user_name", None) | |
html = ( | |
jinja2.Environment(loader=jinja2.BaseLoader()) | |
.from_string(HOME_TEMPLATE) | |
.render( | |
user_name=user_name, | |
) | |
) | |
return html | |
@app.get("/login") | |
async def login(request: Request): | |
"""Redirects to OAuth2 Authorization server login page.""" | |
logger.debug(f"Login: {request.url=}") | |
# Get Authorization service information | |
auth_server_info: dict = await _get_well_known_oidc_services( | |
AUTHORIZATION_SERVER_WELL_KNOWN_URL | |
) | |
# Get Authorization code | |
auth_server_url = auth_server_info["authorization_endpoint"] | |
callback_url = urllib.parse.quote(str(request.url_for("callback")), safe="") | |
auth_url = f"{auth_server_url}?client_id={CLIENT_ID}&redirect_uri={callback_url}&response_type=code" | |
logger.debug(f"Redirecting to auth_url: {auth_url}") | |
return RedirectResponse(auth_url) | |
@app.get("/logout") | |
async def logout(request: Request): | |
"""Redirects to OAuth2 Authorization server logout page.""" | |
logger.debug(f"Logout: {request.url=}") | |
if "user" not in request.session: | |
logger.debug("User not logged in. Redirecting to home.") | |
return RedirectResponse(request.url_for("home")) | |
# Get Authorization service information | |
auth_server_info: dict = await _get_well_known_oidc_services( | |
AUTHORIZATION_SERVER_WELL_KNOWN_URL | |
) | |
end_session_endpoint = auth_server_info["end_session_endpoint"] | |
logger.info(f"Redirecting to end_session_endpoint: {end_session_endpoint}") | |
request.session.clear() | |
return RedirectResponse(end_session_endpoint) | |
@app.get("/callback") | |
async def callback(session_state: str, code: str, request: Request): | |
"""Callback URL for OAuth2 Authorization server.""" | |
logger.debug(f"Callback: {request.url=}") | |
# Get Authorization service information | |
auth_server_info: dict = await _get_well_known_oidc_services( | |
AUTHORIZATION_SERVER_WELL_KNOWN_URL | |
) | |
token_endpoint = auth_server_info["token_endpoint"] | |
# Get authorization tokens | |
data = { | |
"grant_type": "authorization_code", | |
"code": code, | |
"redirect_uri": str(request.url_for("callback")), | |
"client_id": CLIENT_ID, | |
} | |
# Exchange code for tokens | |
async with aiohttp.ClientSession() as session: | |
async with session.post(token_endpoint, data=data) as response: | |
response.raise_for_status() | |
token_dct = await response.json() | |
# Decode JWT token to get user info | |
jwt_dct = jwt.decode(token_dct["access_token"], options={"verify_signature": False}) | |
user_info = { | |
"user_name": jwt_dct.get("email") | |
or jwt_dct.get("preferred_username") | |
or jwt_dct.get("name"), | |
"tokens": token_dct, | |
} | |
# Update session | |
request.session["user"] = user_info | |
return RedirectResponse(request.url_for("home")) | |
# Discovery documents already fetched, keyed by well-known URL. The document
# only lists the server's endpoint locations, so it is fetched once per process
# rather than once per login/logout/callback request.
_OIDC_DISCOVERY_CACHE = {}


async def _get_well_known_oidc_services(
    authorization_server_well_known_url: str,
) -> dict:
    """
    Returns OpenID Connect "Discovery document" of the given authorization server.

    The document is fetched over HTTP on first use. Successful results are
    cached in memory for the lifetime of the process; failures are not cached,
    so the next call retries.

    Raises:
        aiohttp.ClientError: if the URL is invalid or empty, the server cannot
            be reached, it answers with an error status, or the response is
            not served as JSON.
        A timeout error from aiohttp if the request exceeds 10 seconds.
    """
    cached = _OIDC_DISCOVERY_CACHE.get(authorization_server_well_known_url)
    if cached is not None:
        return cached
    logger.debug(
        f"Contacting OIDC well-known URL at {authorization_server_well_known_url!r}."
    )
    # Bound the wait. aiohttp's default total timeout (five minutes) is far too
    # long to hold an incoming browser request open.
    timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(authorization_server_well_known_url) as response:
            response.raise_for_status()
            document = await response.json()
    _OIDC_DISCOVERY_CACHE[authorization_server_well_known_url] = document
    return document
# Cookie-based sessions back ``request.session`` (the logged-in user lives
# there). The signing key comes from the SESSION_SECRET_KEY environment
# variable when it is set and non-empty. Sessions then survive restarts and
# ``--reload``, and work across worker processes that share the variable.
# Otherwise a random per-process key is generated, which invalidates every
# existing session whenever the process restarts.
app.add_middleware(
    SessionMiddleware,
    secret_key=os.environ.get("SESSION_SECRET_KEY") or secrets.token_hex(32),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment