Created July 15, 2021 17:23
Dask <-> BigQuery helpers
import logging
import os
import time
from functools import partial
from typing import Callable, Iterable, List, Optional
from uuid import uuid4

import dask
import dask.dataframe as dd
import pandas as pd
import pyarrow
from dask.base import tokenize
from google.api_core.exceptions import ClientError
from google.cloud import bigquery, bigquery_storage
from pandas_gbq import read_gbq  # or pd.read_gbq

# NOTE: `bigquery_client`, `get_temporary_table`, `full_id`, and `execute_query` are custom
# helpers from our codebase; the first two are reproduced in the comments below.


def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
    """Given a Storage API client and a stream name, return all dataframes for the stream."""
    return [
        pyarrow.ipc.read_record_batch(
            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch), schema
        ).to_pandas()
        for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
    ]
@dask.delayed
def _read_rows_arrow(
    *,
    make_create_read_session_request: Callable,
    partition_field: Optional[str] = None,
    project_id: str,
    stream_name: Optional[str] = None,
    timeout: int,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.

    Args:
        project_id: BigQuery project
        make_create_read_session_request: callable returning the `request` to pass to
            `bqs_client.create_read_session`
        partition_field: BigQuery field for partitions, to be used as Dask index col for divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
        stream_name: BigQuery Storage API Stream "name".
            NOTE: Please set if reading from Storage API without any `row_restriction`.
            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
        timeout: # of seconds an individual read request has before timing out

    NOTE: the `partition_field` and `stream_name` kwargs are mutually exclusive.

    Adapted from https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(session.arrow_schema.serialized_schema))
        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )
        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(bqs_client, stream.name, schema, timeout=timeout)
            ]
            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]
        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
        else:
            raise NotImplementedError("Please specify either `partition_field` or `stream_name`.")
        return pd.concat(shards)
def gbq_as_dask_df(
    project_id: str,
    dataset_id: str,
    table_id: str,
    partition_field: Optional[str] = None,
    partitions: Optional[Iterable[str]] = None,
    row_filter: str = "",
    fields: List[str] = (),
    read_timeout: int = 3600,
) -> dd.DataFrame:
    """Read a table as a dask dataframe using the BigQuery Storage API via Arrow format.

    If `partition_field` and `partitions` are specified, then the resulting dask dataframe
    will be partitioned along the same boundaries. Otherwise, partitions will be approximately
    balanced according to BigQuery stream allocation logic.

    If `partition_field` is specified, it will be fetched even when it is not listed in
    `fields`, so that it is available for dask dataframe indexing.

    Args:
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        partition_field: to specify filters of form "WHERE {partition_field} = ..."
        partitions: all values to select of `partition_field`
        row_filter: text filter to pass to the Storage API as `row_restriction`
        fields: names of the fields (columns) to select (default empty, i.e. "SELECT *")
        read_timeout: # of seconds an individual read request has before timing out

    Returns:
        dask dataframe

    See https://github.com/dask/dask/issues/3121 for additional context.
    """
    if (partition_field is None) and (partitions is not None):
        raise ValueError("Specified `partitions` without `partition_field`.")
    # If `partition_field` is not part of the `fields` filter, fetch it anyway so it can be set
    # as the dask dataframe index. We want this in order to have consistent:
    # BQ partitioning + dask divisions + pandas index values
    if (partition_field is not None) and fields and (partition_field not in fields):
        fields = (partition_field, *fields)
    # These read tasks seem to cause deadlocks (or at least long stuck workers out of touch with
    # the scheduler), particularly when mixed with other tasks that execute C code. Anecdotally,
    # annotating the tasks with a higher priority seems to help (but not fully solve) the issue at
    # the expense of higher cluster memory usage.
    with bigquery_client(project_id, with_storage_api=True) as (
        bq_client,
        bqs_client,
    ), dask.annotate(priority=1):
        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
        if table_ref.table_type == "VIEW":
            # Materialize the view, since the operations below don't work on views.
            logging.warning("Materializing view in order to read into dask. This may be expensive.")
            query = f"SELECT * FROM `{full_id(table_ref)}`"
            table_ref, _, _ = execute_query(query)

        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead use a
        # generator func.
        def make_create_read_session_request(row_filter=""):
            return bigquery_storage.types.CreateReadSessionRequest(
                max_stream_count=0,  # 0 -> use as many streams as BQ Storage will provide
                parent=f"projects/{project_id}",
                read_session=bigquery_storage.types.ReadSession(
                    data_format=bigquery_storage.types.DataFormat.ARROW,
                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                        row_restriction=row_filter,
                        selected_fields=fields,
                    ),
                    table=table_ref.to_bqstorage(),
                ),
            )

        # Create a read session in order to detect the schema.
        # Read sessions are lightweight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(session.arrow_schema.serialized_schema))
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)
            if partitions is None:
                logging.info(
                    "Specified `partition_field` without `partitions`; reading full table."
                )
                partitions = read_gbq(
                    f"SELECT DISTINCT {partition_field} FROM {dataset_id}.{table_id}",
                    project_id=project_id,
                )[partition_field].tolist()
            # TODO generalize to ranges (as opposed to discrete values)
            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
            row_filters = [
                f'{partition_field} = "{partition_value}"' for partition_value in partitions
            ]
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]
        else:
            delayed_kwargs["meta"] = meta
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    stream_name=stream.name,
                    timeout=read_timeout,
                )
                for stream in session.streams
            ]
    return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
def dask_df_to_gbq(
    ddf: dd.DataFrame,
    project_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    table_id: Optional[str] = None,
    bq_schema: Optional[List[bigquery.schema.SchemaField]] = None,
    pa_schema: Optional[pyarrow.Schema] = None,
    partition_by: Optional[str] = None,
    cluster_by: Optional[List[str]] = None,
    clear_existing: bool = True,
    retries: Optional[int] = None,
    write_index: bool = False,
):
    """Upload a dask dataframe to BigQuery using Parquet files in GCS as an intermediary.

    Args:
        ddf: dask dataframe to upload
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        bq_schema: resulting table schema
            TODO infer from data; load_table_from_dataframe tries, but has issues w/ some types
        pa_schema: parquet schema
        partition_by: (date or timestamp) field to partition by
        cluster_by: fields to cluster by
        clear_existing: whether to delete the existing table
        retries: number of retries for the dask computation
        write_index: whether to write the index to parquet

    TODO: Change this to only write to the GCS parquet pattern, and have the framework handle
    downstream resolution into a BQ view of GCS.
    """
    dask_tmp_pattern = "gs://model_bigquery_tmp/dask_dataframe_tmp/{token}/{timestamp}/*.parquet"
    dask_tmp_path = dask_tmp_pattern.format(token=tokenize(ddf), timestamp=int(1e6 * time.time()))
    logging.info(f"Writing dask dataframe to {dask_tmp_path} ...")
    ddf.to_parquet(
        path=os.path.dirname(dask_tmp_path),
        engine="pyarrow",
        write_index=write_index,
        write_metadata_file=False,
        schema=pa_schema,
    )
    with bigquery_client(project_id) as bq_client:
        if table_id:
            if not dataset_id:
                raise ValueError("Cannot pass table_id without dataset_id")
            dataset_ref = bq_client.create_dataset(dataset_id, exists_ok=True)
            table_ref = dataset_ref.table(table_id)
            if clear_existing:
                bq_client.delete_table(table_ref, not_found_ok=True)
        else:
            table_ref = get_temporary_table(bq_client)
            logging.info("Loading to temporary table %s", table_ref.table_id)
        logging.info(
            "Loading %s to %s.%s.%s ...",
            dask_tmp_path,
            table_ref.project,
            table_ref.dataset_id,
            table_ref.table_id,
        )
        job_config = bigquery.LoadJobConfig(
            clustering_fields=cluster_by,
            schema=bq_schema,
            autodetect=(bq_schema is None),
            source_format=bigquery.SourceFormat.PARQUET,
            time_partitioning=(
                bigquery.TimePartitioning(field=partition_by) if partition_by else None
            ),
            write_disposition="WRITE_EMPTY",
        )
        # Reuse the client from the context manager rather than constructing a second one.
        job = bq_client.load_table_from_uri(
            source_uris=dask_tmp_path,
            destination=table_ref,
            job_config=job_config,
        )
        try:
            return job.result()
        except ClientError:
            logging.error(f"Load job failed with the following errors: {job.errors}")
            raise
def query_to_dask_df(
    query: str,
    project_id: Optional[str] = None,
    chunksize: Optional[int] = None,
    tmp_path: Optional[str] = None,
) -> dd.DataFrame:
    """Read a BigQuery result into a dask dataframe using Parquet in GCS as an intermediary.

    GCS export tends to have more balanced sharding and better performance compared to the
    BigQuery Storage API (used in `gbq_as_dask_df`), but unlike that approach it does not allow
    for special handling of partitioned data.
    """
    if tmp_path is None:
        tmp_path = f"gs://model_bigquery_tmp/{uuid4().hex}/*.parquet"
    logging.info("Writing intermediate query_to_dask_df parquet files to %s", tmp_path)
    query = f"EXPORT DATA OPTIONS(uri='{tmp_path}', format='PARQUET') AS\n{query}"
    with bigquery_client(project_id) as bq_client:
        job = bq_client.query(query)
        job.result()  # block until complete
    with dask.annotate(retries=3):  # Some reads seem to fail transiently - see RAD-1820.
        ddf = dd.read_parquet(tmp_path)
    if chunksize:
        num_rows = int(
            job._properties["statistics"]["query"]["exportDataStatistics"]["rowCount"]
        )
        ddf = ddf.repartition(npartitions=max(num_rows // chunksize, 1))
    return ddf
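As a quick illustration, here is a minimal usage sketch of a read-transform-write round trip. All project, dataset, table, and column names are hypothetical, and it assumes the `bigquery_client` and `get_temporary_table` helpers discussed in the comments below are available.

# Hypothetical example: replace the names with your own project/dataset/tables.
ddf = gbq_as_dask_df(
    project_id="my-project",
    dataset_id="analytics",
    table_id="events",
    partition_field="event_date",  # becomes the dask index, with matching divisions
)
counts = ddf.groupby("user_id").size().to_frame("n_events").reset_index()

# Write the derived dataframe back to BigQuery via Parquet on GCS.
dask_df_to_gbq(
    counts,
    project_id="my-project",
    dataset_id="analytics",
    table_id="events_per_user",
    cluster_by=["user_id"],
)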
@bnaul this is great, I'd like to try out this code but I'm missing imports and some information about dependencies. Do you have a snippet of code or a notebook that I can use to run an example?

I'm guessing I need dask, pandas, pyarrow, and google-cloud-bigquery, and it looks like you are also using google-cloud-bigquery-storage. But there are a couple of things I'm not sure about: would you mind pointing out where these come from? Are they custom functions?

- `bigquery_client()`, as in `with bigquery_client(project_id)`
- `get_temporary_table(bq_client)`

Thanks in advance
@ncclementi those look right; the `bigquery_client` helper is

from contextlib import contextmanager
import uuid
import warnings

@contextmanager
def bigquery_client(project_id=_DEFAULT_BQ_PROJECT, with_storage_api=False):
    # Ignore google auth credentials warning
    warnings.filterwarnings(
        "ignore", "Your application has authenticated using end user credentials"
    )
    bq_storage_client = None
    bq_client = bigquery.Client(project_id)
    try:
        if with_storage_api:
            bq_storage_client = bigquery_storage.BigQueryReadClient(
                credentials=bq_client._credentials
            )
            yield bq_client, bq_storage_client
        else:
            yield bq_client
    finally:
        bq_client.close()

(`_DEFAULT_BQ_PROJECT` is a constant specific to our codebase.) And `get_temporary_table` is just

def get_temporary_table(bq_client):
    # Build a reference to a randomly-named table in the `tmp` dataset.
    return bigquery.DatasetReference(bq_client.project, "tmp").table(uuid.uuid4().hex)

where `tmp` is a dataset we use with a 1-day expiration policy.
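For completeness, here is a sketch of one way such an expiring dataset could be created with the google-cloud-bigquery client; the project id is hypothetical.

from google.cloud import bigquery

# One-time setup: a `tmp` dataset whose tables are deleted one day after creation.
client = bigquery.Client("my-project")  # hypothetical project id
dataset = bigquery.Dataset(f"{client.project}.tmp")
dataset.default_table_expiration_ms = 24 * 60 * 60 * 1000  # 1 day
client.create_dataset(dataset, exists_ok=True)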