Skip to content

Instantly share code, notes, and snippets.

@torson
Last active May 3, 2024 16:18
Show Gist options
  • Save torson/51e9010f044db387b087db94725d6b78 to your computer and use it in GitHub Desktop.
Save torson/51e9010f044db387b087db94725d6b78 to your computer and use it in GitHub Desktop.
Script for doing AWS Assume Role with MFA and supporting multiple account contexts
#!/usr/bin/env python3
# Requires: pip install pymemcache cryptography boto3 python-dotenv keyring
# This script asks for the MFA device ARN, the ARN of the role to be assumed, and the MFA code.
# It then creates an env file (named <script name>.env, e.g. aws-assume-role.py.env) holding the default and user-provided configuration,
# assumes the role, caches the access key/secret/token values into the system's vault
# (or memcached, depending on what you choose) and outputs the access key/secret/token envvars
# in such a format that you can call it with shell's 'eval'.
# Inside a shell you can use it like this:
# eval $(/path/aws-assume-role.py)
# This will set the AWS_* envvars into the current shell
# It will create the .env file inside the aws-assume-role.py folder with the configuration
# Add it as an alias to .bashrc or .zshrc:
# alias aws-assume-role='[[ $- == *i* ]] && eval $(/path/aws-assume-role.py -p $(pwd))'
# When aws-assume-role finishes, the AWS_* envvars are set inside the current Terminal shell
# That '[[ $- == *i* ]]' is there so it doesn't execute on shell startup but only inside
# an interactive shell
# The script also supports multiple environments/contexts outside the default one
# To set them up, create additional <script name>.env.<LABEL_X> files (next to the script) that contain:
# ADDITIONAL_ENV_FILE_DIR_PATH_<LABEL_Y>=/path_Y
# You then run the script with argument --run-from-path set to the path where you would
# like to set the context :
# eval $(/path/aws-assume-role.py --run-from-path $(pwd))
# The script then checks if the passed --run-from-path value is a subdirectory of the
# /path_Y value of ADDITIONAL_ENV_FILE_DIR_PATH_<LABEL_Y> . If it is it then sets/uses the
# .env file inside the /path_Y folder.
# This is useful when you use multiple AWS accounts for different stacks.
# Example:
# - you have various stacks scattered all over the place using the same AWS account
# which can be considered as the default Account
# - then you have additional contexts that use different AWS accounts
# /path1/account2/terraform/us-west-1/ec2
# /path1/account3/terraform/eu-central-1/ec2
# /path2/account4/scripts
# - you set this in <script name>.env.others, e.g. aws-assume-role.py.env.others (any label can be used in place of "others")
# ADDITIONAL_ENV_FILE_DIR_PATH_ACCOUNT2=/path1/account2
# ADDITIONAL_ENV_FILE_DIR_PATH_ACCOUNT3=/path1/account3
# ADDITIONAL_ENV_FILE_DIR_PATH_ACCOUNT4=/path2/account4
# - now when you are anywhere (any subfolder) inside the /path1/account2 folder
# and run the script with --run-from-path /path1/account2/sub1/sub2 , it will
# match the /path1/account2 being parent of /path1/account2/sub1/sub2 and will
# use /path1/account2 as the context and create/use /path1/account2/.env
# - so the script automatically changes contexts depending on what is passed with
# --run-from-path . Most practical is to use it with --run-from-path $(pwd) ,
# so when it is run inside /path1/account2/terraform/us-west-1/ec2/terraform.sh,
# the script will use this context to assume the role - account2 in this case
# NOTE: The first time the script runs it will ask you to choose if you'd like
# to assume role (answer "y" , more secure) that was prepared for you or to
# use .aws/credentials file (answer "n" , not secure).
# If you chose "n" then OPERATION_MODE=use_credentials_file was set in .env file
# to save your choice and not ask you again.
# If you'd like to change that to doing assume role then either remove the
# .env file or set OPERATION_MODE=assume_role_with_mfa in the .env file
# The Role's assume policy should require MFA to make this approach effective:
# {
# "Version": "2012-10-17",
# "Statement": [
# {
# "Effect": "Allow",
# "Principal": {
# "AWS": "arn:aws:iam::99999999999:user/John"
# },
# "Action": "sts:AssumeRole",
# "Condition": {"Bool": {"aws:MultiFactorAuthPresent": "true"}}
# }
# ]
# }
## example content of .env file
# OPERATION_MODE=assume_role_with_mfa
# MFA_DEVICE_ARN=arn:aws:iam::000000000000:mfa/example-api
# ASSUME_ROLE_ARN=arn:aws:iam::000000000000:role/user-example-api-role
# CACHING_BACKEND_SERVICE=system-vault
# KEY=avsdgae5t4z63bupvo45nzhjwrtregemgnblsetkgnls
# SESSION_DURATION_SECONDS=7200
# OVERRIDE_CACHING=false
from pymemcache.client.base import Client as MemcacheClient
from cryptography.fernet import Fernet
import boto3
import sys
import time
import os
import glob
from dotenv import load_dotenv
import getpass
import keyring
import argparse
import re
# Command-line interface: both options are optional.
parser = argparse.ArgumentParser()
# Directory the caller is in; compared against ADDITIONAL_ENV_FILE_DIR_PATH_* to pick a context.
parser.add_argument(
    '-p',
    '--run-from-path',
    help='path where this script was called from. This is used for matching the context',
)
# Switch the emitted commands from "export ..." to "launchctl setenv ..." (macOS).
parser.add_argument(
    '-l',
    '--output-launchctl-commands',
    action='store_true',
    default=False,
    help='[MacOS] output launchctl setenv commands instead of export ones so that any app started from this point onwards - not necessarily from this Terminal - will use these envvars',
)
args = parser.parse_args()
# Default configuration
# ---
# Caching backend: "system-vault" | "memcached". Use memcached only when the system
# vault interface is not functioning properly; memcached must then be running.
caching_backend_service = "system-vault"
# Only used with the "memcached" backend. Values stored there are encrypted by this
# script, but memcached is still an extra dependency and potential attack vector.
memcached_server = ('localhost', 11211)
session_duration_seconds = 3600
# ---

# Locate this script; the default env file name is derived from the script's file name.
script_path = os.path.abspath(__file__)
script_dir, script_name = os.path.split(script_path)
# Work from the script's own directory so relative env-file paths resolve next to it.
os.chdir(script_dir)

# Recognised values for caching_backend_service.
caching_backend_service_enum_system_vault = "system-vault"
caching_backend_service_enum_memcached = "memcached"

# Namespace for cached credentials; replaced by the context label when a context matches.
keyring_label = "default"
env_file_default = f"{script_name}.env"
context_dir_path_default = "."
def load_env_files(directory_path=".", pattern=env_file_default+"*"):
    """
    Load every .env file in a directory whose name matches a glob pattern.

    Files are loaded in sorted (lexicographic) order, each with override=True, so a
    key defined in several matching files takes its value from the file that sorts
    last. glob() on its own returns matches in arbitrary order, which would make that
    precedence depend on the filesystem.

    :param directory_path: The directory to search for .env files.
    :param pattern: glob pattern matched against file names inside directory_path.
    """
    full_pattern = os.path.join(directory_path, pattern)
    for env_file in sorted(glob.glob(full_pattern)):
        load_dotenv(dotenv_path=env_file, override=True)
        print(f"Loaded environment variables from: {env_file}", file=sys.stderr)
def env_get_key(key, env_path=env_file_default):
    """
    Return the value of ``key`` from the KEY=VALUE file at ``env_path``.

    Only lines that start exactly with ``"<key>="`` are considered. The first such
    line wins, and everything after its first '=' is returned, with trailing
    whitespace (including the newline) removed. Returns False when the key is absent.

    Exits the process with status 1 when the file is missing or unreadable. This
    helper reads any key (OPERATION_MODE, KEY, ...), so the messages name the file
    generically instead of calling it an encryption key file.
    """
    prefix = f"{key}="
    try:
        with open(env_path, 'r') as file:
            for line in file:
                if line.startswith(prefix):
                    return line.strip().split('=', 1)[1]
        return False
    except FileNotFoundError:
        print(f"Env file not found: {env_path}", file=sys.stderr)
        sys.exit(1)
    except IOError as e:
        print(f"Could not read env file {env_path}: {e}", file=sys.stderr)
        sys.exit(1)
def env_key_exists(key, env_path=env_file_default):
    """
    Tell whether the KEY=VALUE file at ``env_path`` has a line starting with ``"<key>="``.

    Exits the process with status 1 when the file is missing or unreadable.
    """
    wanted_prefix = f"{key}="
    try:
        with open(env_path, 'r') as handle:
            return any(entry.startswith(wanted_prefix) for entry in handle)
    except FileNotFoundError:
        print("File not found.", file=sys.stderr)
        sys.exit(1)
    except IOError as e:
        print(f"Could not read file: {e}", file=sys.stderr)
        sys.exit(1)
def env_add_or_update(key, value, env_path=env_file_default, value_output=True):
    """
    Set ``key=value`` in the KEY=VALUE file at ``env_path``, creating the file if needed.

    If a line starting with ``"<key>="`` exists, the first such line is replaced in
    place. Otherwise a new line is appended. The whole file is then rewritten.
    A confirmation is printed to stderr. Pass value_output=False to keep the value
    (e.g. the encryption key) out of that message.

    :param key: variable name.
    :param value: value to store; formatted with str() (ints are accepted).
    :param env_path: path of the env file.
    :param value_output: whether to echo the value in the confirmation message.
    :return: always False (kept for backward compatibility).
    """
    prefix = f"{key}="
    new_line = f"{key}={value}\n"
    env_content = []
    if os.path.exists(env_path):
        with open(env_path, 'r') as file:
            env_content = file.readlines()

    env_updated = False
    for i, line in enumerate(env_content):
        if line.startswith(prefix):
            env_content[i] = new_line
            env_updated = True
            break

    if not env_updated:
        # Make sure the new entry starts on its own line even if the file
        # does not end with a newline.
        if env_content and not env_content[-1].endswith("\n"):
            env_content[-1] += "\n"
        env_content.append(new_line)

    # Rewrite the whole file ('w', not 'a') so an updated key replaces its old
    # line instead of being appended as a duplicate.
    with open(env_path, 'w') as file:
        file.writelines(env_content)
    print(f"{'Updated' if env_updated else 'Added'} {key}{' = ' if value_output else ''}{value if value_output else ''} in {env_path}", file=sys.stderr)
    return False
def load_encryption_key(env_path=env_file_default):
    """
    Return the Fernet key (as bytes) stored under KEY in ``env_path``, creating one if absent.

    A newly generated key is written to ``env_path`` without echoing its value, so
    later runs can decrypt the credentials cached with it.
    """
    if not env_key_exists("KEY", env_path):
        # No KEY entry yet: generate one and persist it.
        fresh_key = Fernet.generate_key()
        env_add_or_update("KEY", fresh_key.decode(), env_path, value_output=False)
        print(f"New encryption key generated and stored in {env_path} file.", file=sys.stderr)
        return fresh_key
    stored_key = env_get_key("KEY", env_path)
    return stored_key.encode()
def to_boolean(value):
    """Interpret ``value`` as a flag: true for 'true', '1', 't', 'y', 'yes' (case-insensitive)."""
    normalized = str(value).lower()
    return normalized in ('true', '1', 't', 'y', 'yes')
# Cache entry names for AccessKeyId, SecretAccessKey and SessionToken, in that order.
_CREDENTIAL_CACHE_NAMES = ('lemon', 'orange', 'apple')


def _read_cached_credentials(fernet, keyring_label, memcache_client):
    """Return cached (access_key_id, secret_access_key, session_token), or None if unusable.

    None means new credentials must be requested. That happens when nothing is cached,
    the cached session has expired, any entry is missing, or an entry cannot be
    decrypted with the current key (e.g. KEY was regenerated).
    """
    from cryptography.fernet import InvalidToken

    service = 'session_' + keyring_label
    if caching_backend_service == caching_backend_service_enum_system_vault:
        expiration_timestamp = keyring.get_password(service, 'banana')
        try:
            still_valid = bool(expiration_timestamp) and float(expiration_timestamp) > time.time()
        except ValueError:
            still_valid = False  # corrupted timestamp: treat as expired
        if not still_valid:
            return None
        stored = [keyring.get_password(service, name) for name in _CREDENTIAL_CACHE_NAMES]
        encrypted = [None if item is None else item.encode() for item in stored]
    else:
        # memcached expires the entries itself (see _store_credentials).
        # Keys are namespaced per context, like the keyring service name.
        encrypted = [memcache_client.get(f"{service}_{name}") for name in _CREDENTIAL_CACHE_NAMES]

    if any(item is None for item in encrypted):
        return None
    try:
        return tuple(fernet.decrypt(item).decode() for item in encrypted)
    except InvalidToken:
        return None


def _store_credentials(fernet, keyring_label, memcache_client, credentials):
    """Encrypt and cache the STS ``credentials`` in the configured backend (no-op for unknown backends)."""
    service = 'session_' + keyring_label
    values = (credentials['AccessKeyId'], credentials['SecretAccessKey'], credentials['SessionToken'])
    # Expire the cache slightly before the STS session itself ends.
    ttl = session_duration_seconds - 5
    if caching_backend_service == caching_backend_service_enum_system_vault:
        for name, value in zip(_CREDENTIAL_CACHE_NAMES, values):
            keyring.set_password(service, name, fernet.encrypt(value.encode()).decode())
        # Store the expiration timestamp
        keyring.set_password(service, 'banana', str(time.time() + ttl))
    elif caching_backend_service == caching_backend_service_enum_memcached:
        for name, value in zip(_CREDENTIAL_CACHE_NAMES, values):
            memcache_client.set(f"{service}_{name}", fernet.encrypt(value.encode()), expire=ttl)


def assume_role(env_path=env_file_default, keyring_label=keyring_label):
    """
    Return (access_key_id, secret_access_key, session_token) for ASSUME_ROLE_ARN.

    Unless OVERRIDE_CACHING is truthy, still-valid credentials cached for
    ``keyring_label`` are decrypted with the KEY from ``env_path`` and returned.
    Otherwise the user is prompted for an MFA code, STS AssumeRole is called, and
    the new credentials are cached encrypted before being returned.

    Exits with status 1 when the caching backend is unknown (while reading the
    cache) or when the STS call fails.
    """
    print(f"assume_role_arn = {assume_role_arn} , mfa_device_arn = {mfa_device_arn} \n", end='', file=sys.stderr)
    fernet = Fernet(load_encryption_key(env_path))

    # Created up front so it also exists for storing when OVERRIDE_CACHING is set.
    memcache_client = None
    if caching_backend_service == caching_backend_service_enum_memcached:
        memcache_client = MemcacheClient(memcached_server)

    if not to_boolean(os.getenv('OVERRIDE_CACHING')):
        if caching_backend_service not in (caching_backend_service_enum_system_vault,
                                           caching_backend_service_enum_memcached):
            print("ERROR: Caching backend service is not set!", file=sys.stderr)
            sys.exit(1)
        cached = _read_cached_credentials(fernet, keyring_label, memcache_client)
        if cached is not None:
            print("Credentials still valid\n", end='', file=sys.stderr)
            sys.stderr.flush()
            return cached

    # Credentials are expired or missing; request new ones
    print("MFA code: ", end='', file=sys.stderr)
    sys.stderr.flush()
    mfa_code = getpass.getpass(prompt='')
    # envvar AWS_PROFILE needs to be set at this point if not 'default'
    try:
        sts_client = boto3.client('sts')
        response = sts_client.assume_role(
            RoleArn=assume_role_arn,
            RoleSessionName="MySession",
            SerialNumber=mfa_device_arn,
            TokenCode=mfa_code,
            DurationSeconds=session_duration_seconds,
        )
    except Exception as e:
        print(f"Error assuming role: {e}", file=sys.stderr)
        sys.exit(1)

    credentials = response['Credentials']
    _store_credentials(fernet, keyring_label, memcache_client, credentials)
    print(f'New secrets valid for {session_duration_seconds} seconds', file=sys.stderr)
    return credentials['AccessKeyId'], credentials['SecretAccessKey'], credentials['SessionToken']
### START
# checking for context
# Load the extra context definitions (<script name>.env.* files next to the script);
# they provide the ADDITIONAL_ENV_FILE_DIR_PATH_<LABEL> variables scanned below.
load_env_files(pattern=env_file_default+".*")
env_path = env_file_default
context_dir_path = context_dir_path_default


def _match_context_dir(env_dir_path, run_from_path):
    """Return the context directory (with a trailing '/') when run_from_path is inside env_dir_path, else None.

    Paths are compared component by component, so /a/account2 does not match
    /a/account20. An absolute env_dir_path ('~' is expanded) must be a parent of, or
    equal to, run_from_path. A relative env_dir_path (e.g. "account2/") is matched as
    a run of components anywhere in run_from_path; the last occurrence wins, and the
    context is run_from_path cut right after it. An empty env_dir_path never matches.
    """
    env_dir_path = os.path.expanduser(env_dir_path)
    wanted = [part for part in env_dir_path.split('/') if part not in ('', '.')]
    current = [part for part in run_from_path.split('/') if part not in ('', '.')]
    if not wanted:
        return None
    if env_dir_path.startswith('/'):
        if current[:len(wanted)] == wanted:
            return '/' + '/'.join(wanted) + '/'
        return None
    root = '/' if run_from_path.startswith('/') else ''
    for end in range(len(current), len(wanted) - 1, -1):
        if current[end - len(wanted):end] == wanted:
            return root + '/'.join(current[:end]) + '/'
    return None


if args.run_from_path:
    run_from_path = args.run_from_path
    if not run_from_path.endswith('/'):
        run_from_path += '/'
    for key, value in os.environ.items():
        if not key.startswith('ADDITIONAL_ENV_FILE_DIR_PATH_'):
            continue
        env_dir_path = value
        matched_dir = _match_context_dir(env_dir_path, run_from_path)
        if matched_dir is None:
            continue
        # The label after the prefix namespaces this context's cached credentials.
        keyring_label = key.replace("ADDITIONAL_ENV_FILE_DIR_PATH_", "")
        context_dir_path = matched_dir
        print(f"We're inside {env_dir_path} context \n", end='', file=sys.stderr)
        env_path = os.path.join(context_dir_path, env_file_default)
        if not os.path.exists(env_path) or not env_key_exists("AWS_PROFILE", env_path):
            print("Enter your aws-cli .aws/credentials profile for this context (AWS_PROFILE value): ", end='', file=sys.stderr)
            sys.stderr.flush()
            aws_profile = input("").strip()
            env_add_or_update("AWS_PROFILE", aws_profile, env_path)
        break
# Ask once for the mode of operation and persist it in this context's env file.
if not os.path.exists(env_path) or not env_key_exists("OPERATION_MODE", env_path):
    print("Do you want to use 'Assume Role with MFA' mode of operation [y] or using the usual .aws/credentials file (not secure) [n]? ", end='', file=sys.stderr)
    sys.stderr.flush()
    user_choice = input("").strip().lower()
    # Validate the user's choice and set the operation mode
    if user_choice == "y":
        operation_mode = 'assume_role_with_mfa'
    elif user_choice == "n":
        operation_mode = 'use_credentials_file'
    else:
        print("Invalid input. Exiting. ; ", file=sys.stderr)
        sys.exit(1)
    env_add_or_update("OPERATION_MODE", operation_mode, env_path)

if env_get_key("OPERATION_MODE", env_path) == "use_credentials_file":
    print("Using usual .aws/credentials file (not secure)", file=sys.stderr)
    sys.exit(0)

print("Doing Assume Role with MFA", file=sys.stderr)
# Prompt for the ARNs that are not configured yet; strip pasted whitespace.
if not env_key_exists("MFA_DEVICE_ARN", env_path):
    print("Enter your MFA ARN: ", end='', file=sys.stderr)
    sys.stderr.flush()
    mfa_arn = input("").strip()
    env_add_or_update("MFA_DEVICE_ARN", mfa_arn, env_path)
if not env_key_exists("ASSUME_ROLE_ARN", env_path):
    print("Enter the Role ARN you'll be assuming: ", end='', file=sys.stderr)
    sys.stderr.flush()
    role_arn = input("").strip()
    env_add_or_update("ASSUME_ROLE_ARN", role_arn, env_path)

# Persist the script defaults for settings that are not configured yet.
if not env_key_exists("CACHING_BACKEND_SERVICE", env_path):
    env_add_or_update("CACHING_BACKEND_SERVICE", caching_backend_service, env_path)
if not env_key_exists("SESSION_DURATION_SECONDS", env_path):
    env_add_or_update("SESSION_DURATION_SECONDS", session_duration_seconds, env_path)
if not env_key_exists("OVERRIDE_CACHING", env_path):
    env_add_or_update("OVERRIDE_CACHING", "false", env_path)
# Check the backend configured for this context, not just the script default.
if env_get_key("CACHING_BACKEND_SERVICE", env_path) == caching_backend_service_enum_memcached:
    if not env_key_exists("MEMCACHED_SERVER", env_path):
        # Store "host:port": str() of the default tuple ("('localhost', 11211)")
        # is not a usable server spec once read back as a string.
        memcached_host, memcached_port = memcached_server
        env_add_or_update("MEMCACHED_SERVER", f"{memcached_host}:{memcached_port}", env_path)
# Load this context's env file last (override=True) so its values win.
load_env_files(directory_path=context_dir_path, pattern=env_file_default)
mfa_device_arn = os.getenv('MFA_DEVICE_ARN')
assume_role_arn = os.getenv('ASSUME_ROLE_ARN')
# Honour the backend chosen in the env file; fall back to the script default.
caching_backend_service = os.getenv('CACHING_BACKEND_SERVICE') or caching_backend_service
try:
    session_duration_seconds = int(os.getenv('SESSION_DURATION_SECONDS'))
except (TypeError, ValueError):
    print("ERROR: SESSION_DURATION_SECONDS must be an integer", file=sys.stderr)
    sys.exit(1)
if session_duration_seconds < 900:
    print("ERROR: SESSION_DURATION_SECONDS must be at least 900", file=sys.stderr)
    sys.exit(1)
if caching_backend_service == caching_backend_service_enum_memcached:
    memcached_server_spec = (os.getenv('MEMCACHED_SERVER') or '').strip()
    if memcached_server_spec:
        # Accept "host:port" as well as the "('host', port)" form that
        # str(tuple) produces; without a value, keep the default tuple.
        if memcached_server_spec.startswith('('):
            host_part, _, port_part = memcached_server_spec.strip('()').partition(',')
        else:
            host_part, _, port_part = memcached_server_spec.rpartition(':')
        memcached_host = host_part.strip().strip('\'"')
        try:
            memcached_port = int(port_part)
        except ValueError:
            memcached_port = None
        if not memcached_host or memcached_port is None:
            print(f"ERROR: MEMCACHED_SERVER must look like host:port, got {memcached_server_spec!r}", file=sys.stderr)
            sys.exit(1)
        memcached_server = (memcached_host, memcached_port)
# Remove AWS credentials inherited from the calling shell (e.g. from a previous run)
# before calling STS; presumably so the call authenticates via AWS_PROFILE / the
# credentials file rather than with stale session values.
env_vars_to_clear = ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_SESSION_TOKEN')
for stale_var in env_vars_to_clear:
    os.environ.pop(stale_var, None)

access_key_id, secret_access_key, session_token = assume_role(env_path, keyring_label)

# stdout carries only shell commands meant for `eval $(...)`; diagnostics go to stderr.
if args.output_launchctl_commands:
    command_template = 'launchctl setenv {name} "{value}"'
else:
    command_template = 'export {name}="{value}"'
exported_vars = [
    ('AWS_ACCESS_KEY_ID', access_key_id),
    ('AWS_SECRET_ACCESS_KEY', secret_access_key),
    ('AWS_SESSION_TOKEN', session_token),
]
last_index = len(exported_vars) - 1
for index, (name, value) in enumerate(exported_vars):
    # An unquoted $(...) joins the lines, so every command but the last ends with " ; ".
    suffix = ' ; ' if index < last_index else ''
    print(command_template.format(name=name, value=value) + suffix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment