Last active
February 8, 2024 13:38
-
-
Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import logging

import jwt
from fastapi import FastAPI, Request, status
from fastapi.responses import PlainTextResponse
# The ASGI application; the HTTP middleware and the routes below register on it.
app = FastAPI()
# Logger named after this module, used to report token-validation failures.
logger = logging.getLogger(__name__)
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Authenticate every HTTP request with a Bearer JWT before routing it.

    The ``Authorization`` header must have the form ``Bearer <token>``. The
    token is checked by ``validate_token``. A plain-text 401 is returned when
    any of these holds:

    * the header is missing;
    * the header uses another scheme;
    * the token is empty;
    * validation raises;
    * validation yields no claims.

    Authenticated requests are forwarded unchanged.

    Args:
        request: The incoming HTTP request.
        call_next: Callable supplied by FastAPI that forwards the request to
            the next middleware / route handler and returns its response.

    Returns:
        The downstream response, or a 401 ``PlainTextResponse``.
    """
    # RFC 6750 / RFC 7235: a 401 must carry a WWW-Authenticate challenge
    # telling the client which auth scheme is expected.
    resp401 = PlainTextResponse(
        "Unauthorized",
        status_code=status.HTTP_401_UNAUTHORIZED,
        headers={"WWW-Authenticate": "Bearer"},
    )

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return resp401

    # Split "<scheme> <credentials>" without indexing, so a malformed header
    # yields a 401 instead of an IndexError (HTTP 500). The scheme name is
    # case-insensitive per RFC 7235 section 2.1.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return resp401

    # Only token validation is guarded. Authentication is fail-closed: any
    # error while validating (bad signature, expired, JWKS fetch failure, ...)
    # is logged and answered with 401.
    try:
        decoded_token = validate_token(token)
    except Exception as e:
        logger.error("ERROR: Problem validating token: %s", e)
        return resp401

    # Never fall through and return None, which is not a valid response.
    if not decoded_token:
        return resp401

    # Authorization logic (scopes, roles, ...) could inspect decoded_token here;
    # this example only authenticates. call_next is deliberately outside the
    # try block so exceptions raised by route handlers are not misreported as
    # token-validation failures and rewritten into 401s.
    return await call_next(request)
# JWKS endpoint that publishes the Azure AD signing keys.
# Consider moving this and the audience below into configuration.
JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"
# Application (client) ID of the registered API; must match the token's "aud".
API_AUDIENCE = "b79fbf4d-3ef9-4689-8143-76b194e85509"


@functools.lru_cache(maxsize=1)
def _get_jwks_client():
    """Return the shared JWKS client, creating it on first use.

    ``cache_jwk_set``/``lifespan`` cache the fetched key set on the client
    instance. A client built per call would therefore re-download the JWKS
    on every request, so one instance is reused for the whole process.
    Construction is deferred to the first call so importing this module does
    no work.
    """
    return jwt.PyJWKClient(
        uri=JWKS_URI,
        cache_jwk_set=True,
        lifespan=600,  # how long (seconds) the fetched key set is reused
    )


def validate_token(token: str, *, audience: str = API_AUDIENCE, issuer=None) -> dict:
    """Verify a JWT against the published signing keys and return its claims.

    Args:
        token: The raw encoded JWT (without the "Bearer " prefix).
        audience: Expected ``aud`` claim. Defaults to this API's application ID.
        issuer: Optional expected ``iss`` claim. With the multi-tenant
            ``common`` key endpoint, tokens from any tenant are signed by the
            same keys, so pass your tenant's issuer to restrict acceptance.
            ``None`` (default) skips the issuer check, as before.

    Returns:
        The decoded claims dictionary.

    Raises:
        Whatever ``PyJWKClient.get_signing_key_from_jwt`` / ``jwt.decode``
        raise for unknown keys, bad signatures, expired tokens, audience or
        issuer mismatches, missing required claims, or JWKS fetch failures.
    """
    signing_key = _get_jwks_client().get_signing_key_from_jwt(token)
    return jwt.decode(
        token,
        signing_key.key,
        # Pin the algorithm used by Azure AD (and most OIDC providers); never
        # let the token's own header choose it.
        algorithms=["RS256"],
        audience=audience,
        issuer=issuer,
        # Reject tokens without an expiry; otherwise they would be valid forever.
        options={"require": ["exp"]},
    )
# Minimal route used to demonstrate the authentication middleware.
@app.get("/")
def read_root():
    """Return a fixed greeting payload for GET /."""
    payload = dict(Hello="World")
    return payload
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment