Skip to content

Instantly share code, notes, and snippets.

@Melraidin
Created December 20, 2016 17:51
Show Gist options
  • Save Melraidin/7952fbafc4fd8aa0aaf8784d3bb70a3f to your computer and use it in GitHub Desktop.
Save Melraidin/7952fbafc4fd8aa0aaf8784d3bb70a3f to your computer and use it in GitHub Desktop.
Athena query base class for Luigi
"""
Base classes for Athena queries and loads.
"""
import datetime
import jaydebeapi;
import luigi
import luigi.s3
import sources
class AthenaQuery(luigi.Task):
    """
    Base class for Athena queries.

    Provides methods to execute a query against Athena. Typically the
    only Luigi methods needed to be implemented will be requires() and
    run(). Run will usually only require a call to query_store() to
    generate its output at S3.

    Subclasses must define a ``results_path_template`` attribute: it is
    %-formatted with ``self.date`` and appended to ``results_path_base``
    to build ``self.results_path``.

    Note that the ".csv.metadata" key at S3 will be removed to allow
    later loads to Redshift. This may impact viewing the query's
    results in the Athena web UI.
    """

    results_path_base = "s3://500px-emr/athena-results/"

    resources = {"athena_query": 1}

    # The default is computed once, when this class body is executed,
    # not each time a task instance is created.
    date = luigi.DateParameter(default=datetime.date.today())

    access_key = luigi.Parameter(
        default="",
        config_path={"section": "s3", "name": "aws_access_key_id"})
    secret_key = luigi.Parameter(
        default="",
        config_path={"section": "s3", "name": "aws_secret_access_key"})
    athena_jdbc_driver_path = luigi.Parameter(
        default="",
        config_path={"section": "athena", "name": "jdbc_driver_path"})

    def __init__(self, *args, **kwargs):
        """
        Initialise the task and derive ``self.results_path``.

        The path is ``results_path_base`` followed by the subclass's
        ``results_path_template`` formatted with ``self.date``, with any
        trailing slashes removed.
        """
        super(AthenaQuery, self).__init__(*args, **kwargs)

        path = self.results_path_base + self.results_path_template % self.date
        # rstrip() is a no-op when there is no trailing slash, so no
        # separate endswith() check is needed.
        self.results_path = path.rstrip("/")

    def run(self):
        """Abstract: concrete query tasks must provide their own run()."""
        raise NotImplementedError("must override run() method")

    def _execute_query(self, sql):
        """
        Run ``sql`` against Athena and return the result set.

        Results are staged under ``self.results_path``. After the query
        has executed, every key under that path ending in
        ".csv.metadata" is removed from S3.

        The connection opened here is not closed by this method; the
        returned result set is handed to the caller still open.
        """
        props = {"s3_staging_dir": self.results_path,
                 "user": self.access_key,
                 "password": self.secret_key}

        # NOTE(review): confirm the connect(..., props=...) and
        # conn.execute() call shapes against the installed jaydebeapi
        # version; they are kept exactly as originally written.
        conn = jaydebeapi.connect(
            "com.amazonaws.athena.jdbc.AthenaDriver",
            ["jdbc:awsathena://athena.us-east-1.amazonaws.com:443"],
            # Must be read from the task instance: a bare
            # ``athena_jdbc_driver_path`` is not a name in this scope.
            jars=self.athena_jdbc_driver_path,
            props=props)

        rs = conn.execute(sql)

        s3_client = luigi.s3.S3Client()

        # Materialise the listing before deleting anything so the
        # removals cannot interfere with the iteration.
        metadata = [k for k in s3_client.list(self.results_path)
                    if k.endswith(".csv.metadata")]

        for k in metadata:
            s3_client.remove("%s/%s" % (self.results_path, k))

        return rs

    def query_store(self, sql):
        """
        Query Athena and close result set immediately. Results will be
        at self.results_path.
        """
        rs = self._execute_query(sql)
        rs.close()

    def query_retrieve(self, sql):
        """
        Query Athena and return result set.

        This is a generator: the query is not executed until the first
        item is requested. The same result-set object is yielded once
        per successful ``next()`` call, positioned on the current row.
        The result set is closed when iteration finishes, when the
        generator is closed early, or when an error propagates.
        """
        rs = self._execute_query(sql)
        try:
            while rs.next():
                yield rs
        finally:
            rs.close()

    def output(self):
        """Return the S3 target located at ``self.results_path``."""
        return luigi.s3.S3Target(self.results_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment