Skip to content

Instantly share code, notes, and snippets.

@lmazuel
Last active October 20, 2023 17:21
Show Gist options
  • Save lmazuel/54a7e3fb947ea5fa4580afea97f76a72 to your computer and use it in GitHub Desktop.
Save lmazuel/54a7e3fb947ea5fa4580afea97f76a72 to your computer and use it in GitHub Desktop.
HTTPX / requests auth wrapper for azure-identity
import typing
import httpx
import requests
from httpx import Auth, Request, Response
from azure.core.credentials import TokenCredential
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline import PipelineRequest, PipelineContext
from azure.core.pipeline.transport import HttpRequest
from azure.identity import DefaultAzureCredential
class AzureWrapper:
    """Authenticate ``httpx`` / ``requests`` calls with an azure-identity credential.

    An instance is a callable that takes an outgoing request, sets its
    ``Authorization`` header to the ``Bearer <token>`` value produced by
    azure-core's ``BearerTokenCredentialPolicy``, and returns the request.
    It can therefore be passed as ``auth=`` to both ``httpx`` and
    ``requests``. Going through the policy reuses azure-core's token caching.
    """

    def __init__(
        self,
        credential: typing.Optional[TokenCredential] = None,
        scopes: typing.Optional[typing.Sequence[str]] = None,
        **kwargs,
    ) -> None:
        """Create the wrapper.

        :param credential: Credential used to obtain tokens. Defaults to a new
            ``DefaultAzureCredential()``.
        :param scopes: Token scopes to request; a single string is accepted as
            one scope. Defaults to ``["https://management.azure.com/.default"]``.
        :param kwargs: Forwarded to ``BearerTokenCredentialPolicy``, except
            ``enforce_https`` (default ``True``). Set it to ``False`` to allow
            attaching the token to non-https URLs.
        :raises ValueError: If ``scopes`` is empty.
        """
        if scopes is None:
            scopes = ["https://management.azure.com/.default"]
        elif isinstance(scopes, str):
            # Avoid unpacking a lone scope string character by character.
            scopes = [scopes]
        if not scopes:
            raise ValueError("At least one scope is required.")
        if credential is None:
            credential = DefaultAzureCredential()
        # Kept out of the policy kwargs: it is consumed by the pipeline context
        # built in _make_request, not by the policy constructor.
        self._enforce_https = kwargs.pop("enforce_https", True)
        # The policy takes scopes as varargs; forward all of them, not just the first.
        self._policy = BearerTokenCredentialPolicy(credential, *scopes, **kwargs)

    def _make_request(self, url: str = "https://fakeurl") -> PipelineRequest:
        """Build a throwaway pipeline request for the policy to decorate.

        :param url: URL put on the throwaway request. The policy checks this
            URL for https before adding the token, so callers should pass the
            URL of the real outgoing request.
        """
        return PipelineRequest(
            HttpRequest("CredentialWrapper", url),
            PipelineContext(None, enforce_https=self._enforce_https),
        )

    def _get_auth_header(self, url: str = "https://fakeurl") -> str:
        """Return the ``Authorization`` header value (``"Bearer <token>"``).

        Asks the azure-core BearerTokenCredentialPolicy for a token by running
        its public ``on_request`` hook on a fake request. This gives us the
        azure-core token caching for free without relying on private methods,
        which are not guaranteed to stay available.

        :param url: URL of the real outgoing request, used for the policy's
            https check.
        """
        request = self._make_request(url)
        self._policy.on_request(request)
        # The policy writes the complete "Bearer <token>" value; return it as-is.
        return request.http_request.headers["Authorization"]

    def __call__(self, request: Request) -> Request:
        """Set the bearer ``Authorization`` header on *request* and return it.

        Works for both ``httpx.Request`` and ``requests.PreparedRequest``.
        Both expose ``.url`` and a mutable ``.headers`` mapping.
        """
        # str() normalises httpx.URL and plain-str URLs to the same form.
        request.headers["Authorization"] = self._get_auth_header(str(request.url))
        return request
if __name__ == "__main__":
    import os

    # Demo: list the subscription's resource groups twice, once through each
    # HTTP client, sharing one auth object.
    credentials = AzureWrapper(credential=DefaultAzureCredential())
    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "<subscription_id>")
    resource_groups_url = (
        f"https://management.azure.com/subscriptions/{subscription_id}"
        "/resourcegroups?api-version=2020-06-01"
    )
    for http_client in (httpx, requests):
        response = http_client.get(resource_groups_url, auth=credentials)
        print(response.json())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment