Skip to content

Instantly share code, notes, and snippets.

@samdbmg
Last active June 20, 2024 08:28
Show Gist options
  • Save samdbmg/23ffe8bfe0a30ca9072b282a94e9ac5d to your computer and use it in GitHub Desktop.
Save samdbmg/23ffe8bfe0a30ca9072b282a94e9ac5d to your computer and use it in GitHub Desktop.
Proxy a remote API secured using an OAuth2 client credentials grant, exposing it locally without auth
# Taken partly from https://stackoverflow.com/a/36601467
from datetime import datetime, timedelta
import os
from authlib.integrations.httpx_client import AsyncOAuth2Client
from sanic import Sanic
from sanic.log import logger
from sanic.response import raw
# Upstream connection settings, read from the environment when this module is
# imported. Each one is None if its variable is unset.
#
# TOKEN_URL is the full token endpoint URL, e.g. for Keycloak that's
# https://{KEYCLOAK_BASE_URL}/realms/{REALM_NAME}/protocol/openid-connect/token
TOKEN_URL = os.getenv("TOKEN_URL")

# Credentials for the OAuth2 client credentials grant.
CLIENT_ID = os.getenv("CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

# Base URL of the upstream API that incoming requests are forwarded to.
API_URL = os.getenv("API_URL")

app = Sanic("ProxyApp")
@app.before_server_start
async def setup_client(app):
    """Create the shared OAuth2 HTTP client and fetch the first access token.

    The client is stored on ``app.ctx.client`` and the token on
    ``app.ctx.token``, so request handlers can use both.

    Raises:
        RuntimeError: if TOKEN_URL, CLIENT_ID, CLIENT_SECRET or API_URL is
            unset or empty. A misconfigured deployment then fails at startup
            with a clear message, not with an obscure client error later.
    """
    required = {
        "TOKEN_URL": TOKEN_URL,
        "CLIENT_ID": CLIENT_ID,
        "CLIENT_SECRET": CLIENT_SECRET,
        "API_URL": API_URL,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    # Passing token_endpoint + grant_type lets authlib re-fetch a
    # client-credentials token itself when its pre-request check finds the
    # token (nearly) expired. Without them, authlib has no way to refresh a
    # token that lacks a refresh_token, and it raises instead.
    client = AsyncOAuth2Client(
        CLIENT_ID,
        CLIENT_SECRET,
        token_endpoint=TOKEN_URL,
        grant_type="client_credentials",
    )
    app.ctx.client = client
    app.ctx.token = await client.fetch_token(TOKEN_URL, grant_type="client_credentials")
    logger.info("Token fetch complete")


@app.after_server_stop
async def close_client(app):
    """Close the shared HTTP client (and its connection pool) on shutdown."""
    client = getattr(app.ctx, "client", None)
    if client is not None:
        await client.aclose()
async def _refresh_token(leeway_seconds: float = 90):
    """Re-fetch the access token if it has expired or will expire soon.

    Reads ``app.ctx.token["expires_at"]`` (a Unix timestamp). If that time
    minus ``leeway_seconds`` is already in the past, it fetches a new
    client-credentials token and stores it back on ``app.ctx.token``.

    Args:
        leeway_seconds: How long before the stated expiry to refresh. The
            default is deliberately larger than authlib's own 60 s expiry
            margin. Otherwise authlib's check inside ``client.request`` would
            see the token as expired first, and fail if it was not configured
            to re-fetch client-credentials tokens itself.
    """
    expires_at = app.ctx.token.get("expires_at")
    if expires_at is None:
        # The token carries no expiry (e.g. the server omitted one), so there
        # is nothing to compare against; keep using the current token.
        return
    token_expiry = datetime.fromtimestamp(expires_at)
    if token_expiry - timedelta(seconds=leeway_seconds) < datetime.now():
        logger.info("Token expired or expiring soon, refreshing")
        # NOTE(review): concurrent requests may each trigger a refresh here;
        # that is redundant but harmless.
        app.ctx.token = await app.ctx.client.fetch_token(TOKEN_URL, grant_type="client_credentials")
# Request headers that are never forwarded upstream:
# - "host" would name this proxy rather than the upstream API.
# - "accept" is dropped to suppress any HTML render, since links in it
#   won't work through the proxy.
# - "accept-encoding" is left to the HTTP client, so the upstream only uses
#   encodings the client knows how to decode for us.
# - "content-length" is recomputed by the HTTP client from the body we send.
# - the rest are hop-by-hop headers describing the client->proxy connection,
#   not the proxy->upstream one.
_REQUEST_HEADERS_TO_DROP = frozenset(
    {
        "host",
        "accept",
        "accept-encoding",
        "content-length",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)

# Upstream response headers that are never copied back to the client.
# ``res.content`` is the fully-read, already-decoded body, so the upstream's
# content-encoding / content-length / transfer-encoding no longer describe it.
# Forwarding them makes clients try to gunzip plain bytes or read the wrong
# length; Sanic sets its own framing headers. The rest are hop-by-hop headers.
_RESPONSE_HEADERS_TO_DROP = frozenset(
    {
        "content-encoding",
        "content-length",
        "transfer-encoding",
        "connection",
        "keep-alive",
    }
)

# Redirect statuses whose Location header may point back at the upstream API.
_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})

# Fallback local base URL for when the client sent no Host header: the address
# this proxy previously hard-coded (Sanic's default host and port).
_DEFAULT_LOCAL_BASE_URL = "http://127.0.0.1:8000"


def _local_base_url(request):
    """Return this proxy's base URL (scheme://host[:port]) as the client sees it."""
    host = request.host
    if not host:
        return _DEFAULT_LOCAL_BASE_URL
    return f"{request.scheme}://{host}"


def _rewrite_location(location, upstream_base, local_base):
    """Point a redirect target that lives under the upstream API at this proxy.

    Only targets equal to ``upstream_base`` or continuing it with "/", "?" or
    "#" are rewritten. Relative targets, other hosts and mere string prefixes
    (e.g. ".../api" vs ".../apiv2") are returned unchanged.
    """
    if not location.startswith(upstream_base):
        return location
    rest = location[len(upstream_base):]
    if rest and rest[0] not in "/?#":
        return location
    return local_base + rest


@app.route("/<path:path>", methods=["HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"])
async def proxy_request(request, path):
    """Forward one incoming request to the upstream API with an OAuth2 token.

    The path and raw query string are appended to ``API_URL``. Client headers
    are passed through except those in ``_REQUEST_HEADERS_TO_DROP``. The
    ``AsyncOAuth2Client`` on ``app.ctx`` attaches the access token. The
    upstream status, body and (filtered) headers are returned to the caller.
    Redirects into the upstream API are rewritten so the client stays on
    this proxy.

    Args:
        request: The incoming Sanic request.
        path: The request path after the leading "/".

    Returns:
        A Sanic raw response mirroring the upstream response.
    """
    upstream_base = API_URL.rstrip("/")
    target_url = API_URL if path == "" else f"{upstream_base}/{path}"
    if request.query_string:
        # Forward the query string verbatim. Rebuilding it from request.args
        # would drop blank values such as "?flag=".
        target_url = f"{target_url}?{request.query_string}"
    logger.info(f"Proxying {request.method} request to /{path} -> {target_url}")

    # Build a filtered copy instead of deleting from request.headers. Deleting
    # a header the client never sent (e.g. no Accept) raised KeyError, and
    # iterating .items() keeps repeated headers.
    headers = [
        (name, value)
        for name, value in request.headers.items()
        if name.lower() not in _REQUEST_HEADERS_TO_DROP
    ]

    # Refresh the token
    await _refresh_token()
    res = await app.ctx.client.request(
        method=request.method,
        url=target_url,
        headers=headers,
        content=request.body,
    )

    # multi_items() keeps repeated headers such as Set-Cookie as separate entries.
    response_headers = [
        (name, value)
        for name, value in res.headers.multi_items()
        if name.lower() not in _RESPONSE_HEADERS_TO_DROP
    ]
    if res.status_code in _REDIRECT_STATUSES:
        # Rewrite the location header to avoid redirecting upstream
        local_base = _local_base_url(request)
        response_headers = [
            (name, _rewrite_location(value, upstream_base, local_base))
            if name.lower() == "location"
            else (name, value)
            for name, value in response_headers
        ]
    return raw(res.content, res.status_code, response_headers)
def _main():
    """Serve the proxy with Sanic's built-in server, with access logging on."""
    app.run(access_log=True)


if __name__ == "__main__":
    _main()
httpx
authlib
sanic
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment