|
import socket
import ssl
import time
from itertools import chain

import boto3
import click
import dateutil.parser
import polars as pl
from cryptography import x509
|
|
|
|
|
def get_function_enis(function):
    """Find all ENIs related to a VPC-attached Lambda function

    VPC-attached Lambda functions use an ENI for each combination of security groups and attached
    subnet. These ENIs can be shared across multiple functions, so we look across all of the
    function's subnets for Lambda VPC ENIs that exactly match the target function's set of
    security groups.

    Args:
        function: Lambda function name (passed as ``FunctionName`` to
            ``get_function_configuration``).

    Returns:
        A list of network interface dicts as returned by EC2
        ``describe_network_interfaces``.

    Raises:
        ValueError: If the function's configuration has no VPC subnets, i.e. it
            is not attached to a VPC and therefore has no ENIs to find.
    """
    lambda_client = boto3.client("lambda")
    ec2_client = boto3.client("ec2")

    func_config = lambda_client.get_function_configuration(FunctionName=function)
    vpc_config = func_config.get("VpcConfig") or {}
    subnet_ids = vpc_config.get("SubnetIds") or []
    security_group_ids = vpc_config.get("SecurityGroupIds") or []
    if not subnet_ids:
        # Fail early with a readable message instead of a KeyError or an EC2
        # validation error caused by an empty filter value list.
        raise ValueError(
            f"Lambda function {function!r} is not attached to a VPC; no ENIs to inspect"
        )

    filters = [
        {"Name": "description", "Values": ["AWS Lambda VPC ENI*"]},
        {"Name": "subnet-id", "Values": subnet_ids},
        # One filter per group so that every group must be present on the ENI.
        *[
            {"Name": "group-id", "Values": [group_id]}
            for group_id in security_group_ids
        ],
    ]

    paginator = ec2_client.get_paginator("describe_network_interfaces")
    candidates = chain.from_iterable(
        page["NetworkInterfaces"] for page in paginator.paginate(Filters=filters)
    )

    # The filters above only guarantee that each of our groups is attached; an
    # ENI carrying additional groups belongs to a different subnet/group
    # combination, so keep exact matches only.
    wanted_groups = set(security_group_ids)
    return [
        eni
        for eni in candidates
        if {group["GroupId"] for group in eni.get("Groups", [])} == wanted_groups
    ]
|
|
|
|
|
def get_flow_logs(enis, start_dttm, end_dttm, poll_interval=1.0):
    """Query VPC flow logs for TCP traffic to/from the given ENIs.

    Looks up the flow log group named ``/aws/vpc-flow-log/<VpcId>`` (using the
    VPC of the first ENI), starts a CloudWatch Logs Insights query restricted to
    the ENIs and to TCP (protocol 6), and polls until the query finishes.

    Args:
        enis: Non-empty list of network interface dicts; each must provide
            ``VpcId`` and ``NetworkInterfaceId``.
        start_dttm: Start of the query window (``datetime``); converted to whole
            epoch seconds.
        end_dttm: End of the query window (``datetime``); converted to whole
            epoch seconds.
        poll_interval: Seconds to sleep between status polls.

    Returns:
        The final ``get_query_results`` response dict (status ``Complete``).

    Raises:
        ValueError: If ``enis`` is empty.
        LookupError: If no matching flow log group exists.
        RuntimeError: If the query ends in any status other than ``Complete``.
    """
    if not enis:
        raise ValueError("No ENIs supplied; cannot locate VPC flow logs")

    logs_client = boto3.client("logs")

    log_group_pattern = f"/aws/vpc-flow-log/{enis[0]['VpcId']}"
    log_groups = logs_client.describe_log_groups(
        logGroupNamePattern=log_group_pattern
    )["logGroups"]
    if not log_groups:
        raise LookupError(f"No log group found matching {log_group_pattern!r}")
    flow_log_group = log_groups[0]["logGroupName"]

    # Build an insights query focused on TCP traffic (protocol ID 6)
    # to/from our target ENIs during our target timespan
    eni_ids = ", ".join(f"'{eni['NetworkInterfaceId']}'" for eni in enis)
    query = f"""
    filter interfaceId in [{eni_ids}]
    and protocol = 6
    | fields @timestamp, srcAddr, dstAddr, srcPort, dstPort, action, logStatus, start, end, protocol, packets, @logStream, @log
    | sort @timestamp desc
    """

    query_id = logs_client.start_query(
        logGroupName=flow_log_group,
        startTime=int(start_dttm.timestamp()),
        endTime=int(end_dttm.timestamp()),
        queryString=query,
        # Combined with the descending sort, only the newest 10000 records
        # are returned.
        limit=10000,
    )["queryId"]

    # A freshly started query may report "Scheduled" before "Running"; both
    # mean "not finished yet". Sleep between polls rather than busy-looping
    # against the API.
    while True:
        results = logs_client.get_query_results(queryId=query_id)
        status = results["status"]
        if status not in ("Scheduled", "Running"):
            break
        time.sleep(poll_interval)

    if status != "Complete":
        raise RuntimeError(
            f"Flow log query {query_id} ended with status {status!r}"
        )
    return results
|
|
|
|
|
def get_unpaired_requests(results):
    """Summarize flow log TCP requests with no response

    For TCP traffic we expect to see request/response pairs in flow logs. Each request
    to a destination port (like 443 for HTTPS) will also have an arbitrary high-numbered
    source port. Consider a request that goes:

        1.2.3.4:10000 --> 5.6.7.8:443

    We would expect to see a response that goes:

        5.6.7.8:443 --> 1.2.3.4:10000

    And if we _don't_, it means something is either getting lost or the response
    comes outside of the log window we've captured.

    We join the query result set to itself and keep only rows for which no row
    with the reversed source/destination exists (an anti-join). Those rows are
    then counted by destination address/port to look for repeat offenders.

    Args:
        results: A Logs Insights ``get_query_results`` response; ``results["results"]``
            is a list of rows, each a list of ``{"field": ..., "value": ...}`` dicts.

    Returns:
        A Polars DataFrame with columns ``dstPort``, ``dstAddr``, ``count``,
        ``cert_subject`` and ``reverse_lookup``. It is empty when the query
        returned no rows.
    """
    # Munge query results into a list of {field: value} records
    records = [
        {col["field"]: col["value"] for col in row} for row in results["results"]
    ]

    if not records:
        # A frame built from no records has no columns, so the join below would
        # fail; return an empty frame with the usual output columns instead.
        return pl.DataFrame(
            [
                pl.Series("dstPort", [], dtype=pl.Utf8),
                pl.Series("dstAddr", [], dtype=pl.Utf8),
                pl.Series("count", [], dtype=pl.UInt32),
                pl.Series("cert_subject", [], dtype=pl.Utf8),
                pl.Series("reverse_lookup", [], dtype=pl.Utf8),
            ]
        )

    df = pl.DataFrame(records)

    # Anti-join the dataframe to itself to find requests with no matching
    # response. Match on the full reversed 4-tuple: matching on the source
    # addr/port alone would let a reply from a *different* peer to a reused
    # ephemeral port hide an unanswered request.
    no_response = (
        df.join(
            df,
            left_on=["srcAddr", "srcPort", "dstAddr", "dstPort"],
            right_on=["dstAddr", "dstPort", "srcAddr", "srcPort"],
            how="anti",
        )
        .groupby(["dstPort", "dstAddr"])
        .agg(pl.count())
    )

    # Add some context to our no_response destinations if possible. Both
    # helpers perform best-effort network lookups and return '-' on failure.
    return no_response.with_columns(
        [
            pl.struct(["dstAddr", "dstPort"])
            .apply(lambda x: get_cert_subject(x["dstAddr"], x["dstPort"]))
            .alias("cert_subject"),
            pl.col("dstAddr").apply(reverse_lookup).alias("reverse_lookup"),
        ]
    )
|
|
|
|
|
def get_cert_subject(addr, port):
    """Look up the TLS certificate subject served at ``addr``:``port``

    Mostly informative for HTTPS destinations, but cheap enough to
    try everywhere. This is strictly best-effort: any failure at all
    (connection, timeout, parsing) yields the placeholder '-'.

    Returns the subject as an RFC 4514 string, or '-' on failure.
    """
    endpoint = (addr, port)
    try:
        # Give up quickly on destinations that do not answer
        pem_text = ssl.get_server_certificate(endpoint, timeout=2)
        certificate = x509.load_pem_x509_certificate(pem_text.encode("utf8"))
        subject = certificate.subject
    except Exception:
        return "-"
    return subject.rfc4514_string()
|
|
|
|
|
def reverse_lookup(addr):
    """Resolve a destination address to its hostname(s), if any

    A reverse DNS name is usually easier to recognise than a bare
    IP address. This is strictly best-effort: any failure at all
    yields the placeholder '-'.

    Returns the primary hostname followed by any aliases, comma-separated,
    or '-' on failure.
    """
    try:
        lookup = socket.gethostbyaddr(addr)
        primary_name = lookup[0]
        aliases = lookup[1]
        return ",".join([primary_name, *aliases])
    except Exception:
        return "-"
|
|
|
|
|
def _parse_cli_time(value, option_name):
    """Parse a user-supplied time string, reporting failures as a CLI usage error."""
    try:
        return dateutil.parser.parse(value)
    except (ValueError, OverflowError) as err:
        raise click.BadParameter(str(err), param_hint=option_name) from err


@click.command()
@click.option("-f", "--function", required=True, help="A Lambda function name")
@click.option(
    "-s",
    "--start-time",
    required=True,
    help="A start time in any format that dateutil supports (https://dateutil.readthedocs.io/en/stable/examples.html#parse-examples)",
)
@click.option(
    "-e",
    "--end-time",
    required=True,
    help="An end time in any format that dateutil supports",
)
def summarize_tcp_timeouts(function, start_time, end_time):
    """Summarize TCP requests from a VPC-attached Lambda function that got no response.

    Finds the function's ENIs, queries their VPC flow logs between the start
    and end time, and prints a count of unanswered requests per destination.
    """
    # All three options are required: without them dateutil would receive None
    # and fail with an unhelpful traceback.
    start_dttm = _parse_cli_time(start_time, "--start-time")
    end_dttm = _parse_cli_time(end_time, "--end-time")

    # Compare epoch timestamps (the values actually sent to the query) so that
    # a mix of naive and timezone-aware inputs cannot raise on comparison.
    if start_dttm.timestamp() >= end_dttm.timestamp():
        raise click.UsageError("--start-time must be earlier than --end-time")

    func_enis = get_function_enis(function)
    flow_logs = get_flow_logs(func_enis, start_dttm, end_dttm)
    no_response = get_unpaired_requests(flow_logs)
    print(no_response)
|
|
|
|
|
# Script entry point: hand control to the click command, which parses sys.argv.
if __name__ == "__main__":
    summarize_tcp_timeouts()
I see nothing.