Created
February 16, 2022 14:13
-
-
Save alexrashed/32609fa4af9a0351a0ba07f12c5fd6ce to your computer and use it in GitHub Desktop.
Script to test assumptions on the requestURI overloading for rest-* protocols in AWS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from typing import List | |
import logging | |
from botocore.model import OperationModel | |
from localstack.aws.spec import list_services | |
def resolvable(operation_models: "List[OperationModel]") -> bool:
    """Check whether operations sharing a method and request URI can be told apart.

    Assumption being tested: the proper operation can be selected by using

    - the deprecation notice of the operation (non-deprecated is preferred, so
      deprecated operations are ignored entirely), and
    - the set of *required* members that are sent as query string parameters
      or as headers.

    :param operation_models: operations that share the same HTTP method and
        request URI pattern
    :return: ``True`` if every non-deprecated operation has a unique set of
        required query string / header fields, ``False`` if at least two of
        them are indistinguishable by those fields
    """
    candidates = [
        operation_model
        for operation_model in operation_models
        if not operation_model.deprecated
    ]

    signatures = []
    for operation_model in candidates:
        input_shape = operation_model.input_shape
        required_fields = set()
        # botocore reports ``None`` for operations that declare no input; such
        # an operation simply has no required query string / header fields.
        if input_shape is not None:
            for member_name in input_shape.required_members:
                serialization = input_shape.members[member_name].serialization
                location = serialization.get("location")
                if location in ("querystring", "header"):
                    # Identify the field by where it travels and by its wire
                    # name (falling back to the member name). The shape's own
                    # ``name`` is the name of the shape *type* and may be
                    # shared by unrelated members.
                    wire_name = serialization.get("name", member_name)
                    required_fields.add((location, wire_name))
        # frozenset: the declaration order of required members is irrelevant
        # for telling requests apart, and the signature must be hashable.
        signatures.append(frozenset(required_fields))

    # Any signature occurring more than once means two operations cannot be
    # distinguished by their required query string / header fields.
    return len(set(signatures)) == len(signatures)
def main():
    """Report overloaded request URIs for every rest-* service.

    For each service using the ``rest-xml`` or ``rest-json`` protocol, the
    operations are grouped by HTTP method and request URI. Every group with
    more than one operation is logged, together with whether ``resolvable``
    considers the conflict resolvable.
    """
    logging.basicConfig(level=logging.INFO)
    rest_protocols = ("rest-xml", "rest-json")

    for service in list_services():
        # Pattern matching on the requestURI is only done for rest protocols.
        if service.protocol not in rest_protocols:
            continue

        # HTTP method -> request URI -> operations declared for that pair.
        operations_by_method = defaultdict(lambda: defaultdict(list))
        for name in service.operation_names:
            model = service.operation_model(name)
            http_method: str = model.http.get("method")
            uri: str = model.http.get("requestUri")
            operations_by_method[http_method][uri].append(model)

        # Keep only the URIs that are used by more than one operation.
        overloaded = {}
        for http_method, by_uri in operations_by_method.items():
            shared = {
                uri: models for uri, models in by_uri.items() if len(models) > 1
            }
            if shared:
                overloaded[http_method] = shared

        if not overloaded:
            continue

        logging.info(f"-------------- {service.service_name} --------------")
        for http_method, shared in overloaded.items():
            logging.info(f"-> {http_method}")
            for operation_models in shared.values():
                logging.info(f"└-> {operation_models}")
                if resolvable(operation_models):
                    logging.info(" └-> Conflict resolved.")
                else:
                    logging.error(f" └-> Non-resolvable conflict found: {operation_models}")
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment