Created
November 5, 2018 22:16
-
-
Save pbrumblay/cb6edc3774c3d1a45605074b80a5797a to your computer and use it in GitHub Desktop.
Airflow custom Google Cloud Storage Hook with resumable uploads, partial downloads, and compose (everyone else calls it "concatenating") functionality
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from google.cloud import storage | |
from airflow.hooks.base_hook import BaseHook | |
from airflow.utils.log.logging_mixin import LoggingMixin | |
import random | |
import string | |
class GCSCustomHook(BaseHook, LoggingMixin):
    """
    Custom Google Cloud Storage hook with resumable uploads, partial
    downloads, and compose (concatenation) support.

    Uses the google-cloud-storage client library directly. Cannot use or
    inherit from Airflow's built-in GCS hook because it relies on the
    deprecated oauth2client library; only the default security context
    is supported.
    """

    def __init__(self, storage_conn_id='google_cloud_storage_default'):
        self.storage_conn_id = storage_conn_id
        # Lazily-created storage.Client, cached by get_conn().
        self.conn = None

    def get_conn(self):
        """
        Return a cached google.cloud.storage.Client.

        The GCP project is read from the Airflow connection's extra JSON
        under the 'project' key.
        """
        if self.conn is None:
            params = self.get_connection(self.storage_conn_id)
            project = params.extra_dejson.get('project')
            self.log.info('Getting connection using project %s', project)
            self.conn = storage.Client(project)
        return self.conn

    def _get_bucket(self, bucket_name):
        """Look up *bucket_name*, raising ValueError if it does not exist."""
        bucket = self.get_conn().lookup_bucket(bucket_name)
        if bucket is None:
            # BUG FIX: several callers used 'Could not find bucket ' % name
            # (no %s placeholder), which raised TypeError instead of the
            # intended ValueError.
            raise ValueError('Could not find bucket %s' % bucket_name)
        return bucket

    def list(self, bucket_name, prefix):
        """Return an iterator over blobs in *bucket_name* matching *prefix*."""
        return self._get_bucket(bucket_name).list_blobs(prefix=prefix)

    def compose(self, bucket_name, prefix, new_blob_name):
        """Recursively combine (aka "compose") blob shards in groups of 32.

        GCS compose() accepts at most 32 sources per call, so the blobs
        matching *prefix* are composed in groups of 32 under a random
        intermediate prefix, the sources are deleted, and the process
        repeats until one blob remains, which is renamed *new_blob_name*.

        :raises ValueError: if the bucket does not exist or no blobs
            match *prefix* (previously this case recursed forever).
        """
        # list() returns an iterator; materialize it so we can count entries.
        source_blobs = list(self.list(bucket_name, prefix))
        bucket = self._get_bucket(bucket_name)

        # BUG FIX: with zero matching blobs the original recursed
        # indefinitely on ever-new random prefixes. Fail fast instead.
        if not source_blobs:
            raise ValueError(
                'No blobs found with prefix %s in bucket %s' % (prefix, bucket_name))

        # Recursive base case: a single blob with the given prefix is
        # simply renamed to the desired name.
        if len(source_blobs) == 1:
            self.log.info("Found 1 blob matching prefix, renaming to: %s ", new_blob_name)
            bucket.rename_blob(source_blobs[0], new_blob_name)
            return

        # Create a new prefix to compose blobs in groups of 32 into.
        random_name = ''.join(random.choice(string.ascii_lowercase) for _ in range(10))

        # compose() does not set a content type on the result; propagate the
        # sources' type explicitly.
        # Workaround: https://github.com/googleapis/google-cloud-python/issues/5834
        blob_content_type = source_blobs[0].content_type

        # Group the blobs in chunks of 32 (the compose() hard limit) and
        # compose each group under the intermediate prefix.
        groups = [source_blobs[i:i + 32] for i in range(0, len(source_blobs), 32)]
        for group_index, group in enumerate(groups):
            for blob in group:
                self.log.info("Adding blob to group [%s]: %s ", group_index, blob.path)
            new_blob = bucket.blob(random_name + "-" + str(group_index))
            self.log.info("Creating blob: %s ", new_blob.path)
            new_blob.content_type = blob_content_type
            new_blob.compose(group)

        # Delete all blobs with the old prefix to clean up.
        for blob in source_blobs:
            self.log.info("Deleting blob: %s", blob.path)
            blob.delete()

        # Repeat the process using the new intermediate prefix.
        self.compose(bucket_name, random_name, new_blob_name)

    def resumable_upload(self, file_object, bucket_name, blob_name):
        """
        Stream *file_object* to GCS using a resumable (chunked) upload.

        Setting a chunk size on the blob makes the client use the resumable
        upload API with file-like objects, so the entire file never has to
        be written to disk (or held in memory) before transmission.

        :param file_object: a file-like object to read from
        :type file_object: io.IOBase
        :param bucket_name: a GCS bucket
        :type bucket_name: str
        :param blob_name: the destination path (blob)
        :type blob_name: str
        :raises ValueError: if the bucket does not exist
        """
        bucket = self._get_bucket(bucket_name)
        self.log.info("Found bucket starting upload to gs://%s/%s", bucket_name, blob_name)
        blob = bucket.blob(blob_name)
        blob.chunk_size = 1024 * 1024 * 3  # 3 MB chunks => resumable upload
        blob.upload_from_file(file_object)

    def download_file_part(self, bucket_name, blob_name, start, end, file_name):
        """
        Download bytes [*start*, *end*] of *blob_name* to local *file_name*.

        :raises ValueError: if the bucket does not exist
        """
        bucket = self._get_bucket(bucket_name)
        # BUG FIX: the original log call had four placeholders but only two
        # arguments, causing a logging formatting error on every call.
        self.log.info("Found bucket %s. Downloading [%s] bytes %s-%s to %s",
                      bucket_name, blob_name, start, end, file_name)
        blob = bucket.blob(blob_name)
        blob.download_to_filename(file_name, start=start, end=end)

    def download_file_string(self, bucket_name, blob_name):
        """
        Download *blob_name* and return its contents as bytes.

        :raises ValueError: if the bucket does not exist
        """
        bucket = self._get_bucket(bucket_name)
        # BUG FIX: placeholder/argument mismatch in the original log call.
        self.log.info("Found bucket %s. Downloading [%s]", bucket_name, blob_name)
        blob = bucket.blob(blob_name)
        # BUG FIX: the original discarded the downloaded data and implicitly
        # returned None; return the bytes to the caller.
        return blob.download_as_string()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment