Skip to content

Instantly share code, notes, and snippets.

@alexrashed
Created February 16, 2022 14:13
Show Gist options
  • Save alexrashed/32609fa4af9a0351a0ba07f12c5fd6ce to your computer and use it in GitHub Desktop.
Save alexrashed/32609fa4af9a0351a0ba07f12c5fd6ce to your computer and use it in GitHub Desktop.
Script to test assumptions about requestURI overloading in AWS services that use the rest-* protocols (rest-xml, rest-json)
from collections import defaultdict
from typing import List
import logging
from botocore.model import OperationModel
from localstack.aws.spec import list_services
def _required_location_fields(operation_model) -> frozenset:
    """Return the required query/header members of an operation's input.

    Each entry is a ``(location, wire_name)`` pair, e.g.
    ``("querystring", "uploads")``. Using the wire name matters: the member
    shape's ``.name`` is the *shape* name (such as ``"String"``), which many
    distinct members share.
    """
    input_shape = operation_model.input_shape
    if input_shape is None:
        # Operations without an input shape have no required members at all.
        return frozenset()
    fields = set()
    for member_name in input_shape.required_members:
        member = input_shape.members[member_name]
        location = member.serialization.get("location")
        if location in ("querystring", "header"):
            # "name" holds the locationName (wire name); fall back to the member name.
            fields.add((location, member.serialization.get("name", member_name)))
    return frozenset(fields)


def resolvable(operation_models: "List[OperationModel]") -> bool:
    """Check whether operations sharing one HTTP method + requestUri can be told apart.

    Assumption under test: the proper implementation can be found by using
      - the required _query_ and _header_ members, and
      - the deprecation notice of the operation (prefer non-deprecated).

    The conflict is considered resolvable if no two candidate operations have
    the same set of required query/header members.

    :param operation_models: the operations whose method and requestUri collide
    :return: True if every candidate operation has a distinct set of required
        query/header members, False otherwise
    """
    candidates = [model for model in operation_models if not model.deprecated]
    if not candidates:
        # Every operation is deprecated, so deprecation cannot break the tie;
        # compare all of them instead of vacuously reporting "resolvable".
        candidates = list(operation_models)

    seen_signatures = set()
    for operation_model in candidates:
        # Sets make the comparison independent of the order of required members.
        signature = _required_location_fields(operation_model)
        if signature in seen_signatures:
            return False
        seen_signatures.add(signature)
    return True
def _overloaded_request_uris(service):
    """Group a service's operations by HTTP method and requestUri, keeping only collisions.

    :param service: a service model exposing ``operation_names`` and
        ``operation_model()`` (presumably a botocore ``ServiceModel`` as
        yielded by ``list_services()`` — confirm)
    :return: ``{method: {request_uri: [OperationModel, ...]}}`` containing
        only the request URIs shared by more than one operation
    """
    operations_by_method = defaultdict(lambda: defaultdict(list))
    for operation_name in service.operation_names:
        operation_model = service.operation_model(operation_name)
        method: str = operation_model.http.get("method")
        request_uri: str = operation_model.http.get("requestUri")
        operations_by_method[method][request_uri].append(operation_model)

    overloaded = {}
    for method, operations_by_uri in operations_by_method.items():
        duplicates = {uri: models for uri, models in operations_by_uri.items() if len(models) > 1}
        if duplicates:
            overloaded[method] = duplicates
    return overloaded


def main():
    """Log, per rest-* service, every requestUri shared by several operations.

    For each overloaded (method, requestUri) pair, it also logs whether
    ``resolvable`` considers the conflict resolvable. ``resolvable`` decides
    this from required query/header members and the deprecation flag.
    """
    logging.basicConfig(level=logging.INFO)
    for service in list_services():
        # We only do pattern matching on the requestURI for rest protocols.
        if service.protocol not in ["rest-xml", "rest-json"]:
            continue
        overloaded = _overloaded_request_uris(service)
        if not overloaded:
            continue
        logging.info(f"-------------- {service.service_name} --------------")
        for method, patterns in overloaded.items():
            logging.info(f"-> {method}")
            for pattern, operation_models in patterns.items():
                # Include the colliding pattern so the conflict can be located.
                logging.info(f"└-> {pattern}: {operation_models}")
                if not resolvable(operation_models):
                    logging.error(f" └-> Non-resolvable conflict found: {operation_models}")
                else:
                    logging.info(" └-> Conflict resolved.")
if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment