This is Terraform and Python code to help you track AWS Batch Jobs without having to poll Batch's APIs.
It sets up this integration:
AWS Batch --(job state change)--> EventBridge --(message)--> SQS
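Each state change shows up on the queue as a JSON message. Roughly, abridged to the fields the code below actually reads (shape per the AWS Batch event docs; values are illustrative):

```python
# Abridged "Batch Job State Change" event, shown as a Python dict.
# Only the fields this repo's code reads are included; values are made up.
event = {
    "source": "aws.batch",
    "detail-type": "Batch Job State Change",
    "detail": {
        "jobId": "b0aca7b2-a37e-4d07-93f8-52dd1eaf0b12",
        "status": "RUNNING",  # SUBMITTED, PENDING, ..., SUCCEEDED, FAILED
        "parameters": {"workflow_id": "<base62-encoded workflow id>"},
    },
}
```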
You can do it in pure Terraform, or a mix of Python/Terraform.
The Terraform shows how to do this as a per-region queue across all workflows.
The Python shows how to do this as a workflow-specific queue, matching only job state updates from AWS Batch jobs that have a specific parameter (workflow_id) set.
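Concretely, that matching is done with an EventBridge event pattern like this (the same pattern `create_subscription_for_batch_jobs` installs in `batch.py` below):

```python
# EventBridge event pattern: match Batch job state changes whose job
# parameters carry this workflow's encoded id.
event_pattern = {
    "source": ["aws.batch"],
    "detail-type": ["Batch Job State Change"],
    "detail": {"parameters": {"workflow_id": ["<encoded workflow id>"]}},
}
```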
I set this up for Prefect, which uses UUIDs as workflow IDs. I didn't trust AWS name length limits, so I base62-encode the UUIDs to save characters.
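The `id_encoding` module isn't included in this README. A minimal sketch of what its `encode_uuid` could look like, assuming a plain base62 alphabet (a 128-bit UUID comes out at 22 characters or fewer, versus 36 for the hyphenated hex form):

```python
import uuid

BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


def encode_uuid(value: uuid.UUID) -> str:
    """Base62-encode a UUID's 128-bit integer value."""
    n = value.int
    if n == 0:
        return BASE62[0]
    digits = []
    while n:
        n, rem = divmod(n, 62)
        digits.append(BASE62[rem])
    return "".join(reversed(digits))
```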
`create_subscription_for_batch_jobs` creates the following things in AWS:

- an SQS queue named `batch-job-states-{encoded_id}`
- an EventBridge rule named `batch-job-states-{encoded_id}` that matches job state events
- an EventBridge target that connects the two
- an SQS queue policy that permits the EventBridge rule to push events to the queue

and `clean_up_subscription` cleans them all up.
Usage:
```python
import asyncio
import uuid

import boto3

from batch import (
    create_subscription_for_batch_jobs,
    clean_up_subscription,
    submit_sleep_job,
    wait_for_jobstates,
)

workflow_id = uuid.uuid4()  # or get it from wherever
session = boto3._get_default_session()

subscription = create_subscription_for_batch_jobs(workflow_id, session)
job_id = submit_sleep_job(workflow_id, 30, session)  # e.g. a job tagged with this workflow id

# Right now, this blocks forever, so you'll have to figure out what you
# want to do with each job state message, and leave the function at some
# point. It's a coroutine, so run it under an event loop:
asyncio.run(wait_for_jobstates(subscription, session))

# later...
clean_up_subscription(subscription, session)
```
batch.py
```python
import asyncio
import json
import uuid
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Callable

import boto3
from mypy_boto3_sqs.service_resource import Message as SqsMessage

from .id_encoding import encode_uuid


@dataclass(frozen=True)
class QueueData:
    url: str
    arn: str


@dataclass(frozen=True)
class RuleData:
    rule_name: str
    target_ids: list[str]


@dataclass(frozen=True)
class Subscription:
    queue: QueueData
    rule: RuleData


def create_subscription_for_batch_jobs(
    workflow_id: uuid.UUID, aws: boto3.Session
) -> Subscription:
    encoded_id = encode_uuid(workflow_id)
    eventbridge = aws.client("events")

    # Rule: match Batch job state changes carrying this workflow's id.
    rule_name = f"batch-job-states-{encoded_id}"
    rule = eventbridge.put_rule(
        Name=rule_name,
        Description="match Batch Job States",
        EventPattern=json.dumps(
            {
                "source": ["aws.batch"],
                "detail-type": ["Batch Job State Change"],
                "detail": {
                    "parameters": {"workflow_id": [encoded_id]},
                },
            }
        ),
        Tags=[{"Key": "flow_run_id", "Value": encoded_id}],
    )

    # Queue that the rule will deliver matching events into.
    queue_name = f"batch-job-states-{encoded_id}"
    queue = aws.resource("sqs").create_queue(
        QueueName=queue_name, tags={"flow_run_id": encoded_id}
    )
    queue_arn = queue.attributes["QueueArn"]

    # Queue policy: allow EventBridge (this rule only) to send messages.
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "SQSAccess",
                "Effect": "Allow",
                "Principal": {"Service": "events.amazonaws.com"},
                "Action": "sqs:SendMessage",
                "Resource": queue_arn,
                "Condition": {"ArnEquals": {"aws:SourceArn": rule["RuleArn"]}},
            }
        ],
    }
    queue.set_attributes(Attributes={"Policy": json.dumps(policy)})

    # Target: connect the rule to the queue.
    eventbridge.put_targets(
        Rule=rule_name,
        EventBusName="default",
        Targets=[
            {
                "Id": queue_name,
                "Arn": queue_arn,
            }
        ],
    )
    return Subscription(
        queue=QueueData(url=queue.url, arn=queue_arn),
        rule=RuleData(rule_name=rule_name, target_ids=[queue_name]),
    )
def clean_up_subscription(sub: Subscription, aws: boto3.Session) -> None:
    eventbridge = aws.client("events")
    # Targets must be removed before the rule itself can be deleted.
    eventbridge.remove_targets(Rule=sub.rule.rule_name, Ids=sub.rule.target_ids)
    eventbridge.delete_rule(Name=sub.rule.rule_name)
    sqs = aws.client("sqs")
    sqs.delete_queue(QueueUrl=sub.queue.url)
# Example code for running an AWS Batch job.
# This just runs an alpine container that sleeps for some time.
#
# Key code: the part that sets the container's workflow_id parameter
# for matching in the EventBridge rule.
def submit_sleep_job(workflow_id: uuid.UUID, seconds: int, aws: boto3.Session) -> str:
    encoded_id = encode_uuid(workflow_id)
    batch_client = aws.client("batch")
    job = batch_client.submit_job(
        jobName=f"sleep-{seconds}",
        jobQueue="primary_queue_shared_us_west_2_dev_630831553409",
        jobDefinition="alpine",
        containerOverrides={
            "command": [
                "/bin/sh",
                "-c",
                f"echo 'sleeping..' ; sleep {seconds} ; echo 'awake'",
            ]
        },
        tags={"workflow_id": encoded_id},
        parameters={"workflow_id": encoded_id},
    )
    return job["jobId"]
def matches_job(job_id: str) -> Callable[[SqsMessage], bool]:
    def matcher(message: SqsMessage) -> bool:
        parsed_body = json.loads(message.body)
        return parsed_body["detail"]["jobId"] == job_id

    return matcher


async def wait_for_jobstates(sub: Subscription, aws: boto3.Session) -> None:
    queue = aws.resource("sqs").Queue(sub.queue.url)
    while True:
        await asyncio.sleep(1)  # Let other code run between these long polls
        messages = queue.receive_messages(WaitTimeSeconds=20, MaxNumberOfMessages=10)
        for message in messages:
            job_state_message = json.loads(
                message.body, object_hook=lambda d: SimpleNamespace(**d)
            )
            workflow_id: str = job_state_message.detail.parameters.workflow_id
            job_state: str = job_state_message.detail.status
            # Do whatever you want with the job state here. Check out
            # https://docs.aws.amazon.com/batch/latest/userguide/job_states.html
            queue.delete_messages(
                Entries=[
                    {"Id": message.message_id, "ReceiptHandle": message.receipt_handle}
                ]
            )
```
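`matches_job` isn't wired into `wait_for_jobstates` above; it's there for filtering messages down to a single job. A rough sketch of one way to use it, assuming you only care about one job per queue (note it deletes every message it reads, so other jobs' updates on the same queue would be dropped):

```python
TERMINAL_STATES = {"SUCCEEDED", "FAILED"}


async def wait_for_job(sub: Subscription, job_id: str, aws: boto3.Session) -> str:
    """Block until job_id reaches a terminal state, then return that state."""
    matcher = matches_job(job_id)
    queue = aws.resource("sqs").Queue(sub.queue.url)
    while True:
        await asyncio.sleep(1)  # yield to other tasks between long polls
        for message in queue.receive_messages(
            WaitTimeSeconds=20, MaxNumberOfMessages=10
        ):
            status = json.loads(message.body)["detail"]["status"]
            matched = matcher(message)
            queue.delete_messages(
                Entries=[
                    {"Id": message.message_id, "ReceiptHandle": message.receipt_handle}
                ]
            )
            if matched and status in TERMINAL_STATES:
                return status
```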