Skip to content

Instantly share code, notes, and snippets.

@spencerkclark
Created April 15, 2019 20:14
Show Gist options
  • Save spencerkclark/6a8e05a492111e52d8d8fb407d332611 to your computer and use it in GitHub Desktop.
Save spencerkclark/6a8e05a492111e52d8d8fb407d332611 to your computer and use it in GitHub Desktop.
xspharm

xspharm

Yet another wrapper for pyspharm...

Existing wrappers, windspharm and animal-spharm, lack dask support for parallelizing operations. This wrapper enables the use of dask for at least the subset of pyspharm functions needed to compute closed tracer budgets following the methods of Hill et al. (2017):

Hill, S. A., Ming, Y., Held, I. M., & Zhao, M. (2017). A Moist Static Energy Budget–Based Analysis of the Sahel Rainfall Response to Uniform Oceanic Warming. Journal of Climate, 30(15), 5637–5660. https://doi.org/10.1175/JCLI-D-16-0785.1

Someday I may give this a more robust package structure and add some tests, but for now I'm simply posting it as a gist in case others find it useful (a) for dask/xarray-compatible spherical harmonic computations or (b) as a fairly involved example use case for xarray.apply_ufunc and dask.array.map_blocks.

Requirements

All of these requirements are available through conda-forge:

$ conda install -c conda-forge dask numpy pyspharm xarray
import dask.array as darray
import numpy as np
import xarray as xr
from spharm import Spharmt
# Names of the horizontal dimensions expected on input DataArrays.
LAT_STR = 'lat'
LON_STR = 'lon'
_HORIZONTAL_DIMS = (LAT_STR, LON_STR)
# Name of the single dimension onto which every non-horizontal dimension is
# stacked before handing data to spharm.
_NON_HORIZONTAL_DIM = 'non_horizontal'
# Name of the spectral-coefficient dimension of spectral-space outputs.
_HARMONIC_DIM = 'harmonic'
# Sphere radius passed to Spharmt as ``rsphere`` (presumably metres -- the
# value matches Earth's mean radius; confirm against the pyspharm docs).
RADIUS = 6370997.
def wraps_dask_array(arr):
    """Return True when the data backing ``arr`` is a dask array."""
    backing = arr.data
    return isinstance(backing, darray.core.Array)
def non_horizontal_dims(arr):
    """Return, as a set, the names of all dims of ``arr`` except lat/lon."""
    horizontal = set(_HORIZONTAL_DIMS)
    return {dim for dim in arr.dims if dim not in horizontal}
def stack_non_horizontal_dims(arr):
    """If present, stack all non-horizontal dims onto one dimension.

    The dims are stacked in the order they appear in ``arr.dims``. Taking
    them from a set instead would make the multi-index level order (and so
    the dim order after a later ``unstack``) depend on string-hash
    randomization, i.e. vary from one interpreter run to the next.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray
        ``arr`` unchanged if it has only horizontal dims; otherwise the
        result of stacking every other dim onto ``_NON_HORIZONTAL_DIM``.
    """
    dims_to_stack = tuple(dim for dim in arr.dims
                          if dim not in _HORIZONTAL_DIMS)
    if dims_to_stack:
        arr = arr.stack(**{_NON_HORIZONTAL_DIM: dims_to_stack})
    return arr
def order_dims_for_spharm(arr):
    """Order dims such that lat and lon come first.

    Remaining dims keep their relative order from ``arr.dims``, so the
    result is deterministic even when several non-horizontal dims are
    present (ordering them via a set would depend on hash randomization).

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray containing both lat and lon dims.

    Returns
    -------
    xr.DataArray
        ``arr`` transposed to (lat, lon, *other dims).
    """
    others = tuple(dim for dim in arr.dims if dim not in _HORIZONTAL_DIMS)
    return arr.transpose(*(_HORIZONTAL_DIMS + others))
def chunk_in_spherical_shells(arr):
    """Rechunk dask-backed data so each chunk holds whole lat-lon shells.

    Data not backed by dask are returned untouched.
    """
    if not wraps_dask_array(arr):
        return arr
    whole_shell = {LAT_STR: arr.sizes[LAT_STR],
                   LON_STR: arr.sizes[LON_STR]}
    return arr.chunk(whole_shell)
def flip_lat(arr):
    """Return ``arr`` with the order of its latitude dimension reversed."""
    reversed_lat = slice(None, None, -1)
    return arr.isel({LAT_STR: reversed_lat})
def orient_latitude_north_south(arr):
    """Orient data such that northern latitudes come first.

    Data whose latitude coordinate is strictly increasing (south to north)
    are reversed along latitude; anything else is returned unchanged.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The (possibly flipped) array and a flag noting whether it was
        flipped, so callers can undo the flip on their results.
    """
    # Vectorized reduction; iterating the DataArray with the builtin ``all``
    # would build one 0-d DataArray per latitude. Results are identical,
    # including True for an empty diff (single latitude).
    increasing = bool((arr[LAT_STR].diff(LAT_STR) > 0.).all())
    if increasing:
        return flip_lat(arr), True
    return arr, False
def prep_for_spharm(arr):
    """Put a DataArray into the layout the spharm wrappers here expect.

    Non-horizontal dims are stacked onto one dimension, lat and lon are
    moved to the front, dask data are rechunked into whole lat-lon shells,
    and latitude is oriented north to south.

    Parameters
    ----------
    arr : xr.DataArray
        Input DataArray

    Returns
    -------
    xr.DataArray, bool
        The prepared array and whether its latitude axis was flipped.
    """
    for step in (stack_non_horizontal_dims, order_dims_for_spharm,
                 chunk_in_spherical_shells):
        arr = step(arr)
    return orient_latitude_north_south(arr)
def create_spharmt(arr):
    """Build a Gaussian-grid Spharmt matching the horizontal shape of arr."""
    nlon = arr.sizes[LON_STR]
    nlat = arr.sizes[LAT_STR]
    return Spharmt(nlon, nlat, rsphere=RADIUS, gridtype='gaussian')
def n_harmonics(arr, n_trunc=None):
    """Count the spectral coefficients of a triangular truncation.

    Parameters
    ----------
    arr : xr.DataArray
        Only consulted when ``n_trunc`` is None, in which case the
        truncation defaults to (number of latitudes - 1).
    n_trunc : int, optional
        Truncation wavenumber.

    Returns
    -------
    int
        ``(n_trunc + 1) * (n_trunc + 2) // 2``
    """
    truncation = arr.sizes[LAT_STR] - 1 if n_trunc is None else n_trunc
    return (truncation + 1) * (truncation + 2) // 2
def _grdtospec(st, arr, harmonics):
    """Wrap ``Spharmt.grdtospec`` so it also accepts dask arrays.

    Parameters
    ----------
    st : Spharmt
        Transform object whose grid matches ``arr``.
    arr : np.ndarray or dask.array.Array
        Gridded data with lat and lon as the first two axes and, optionally,
        one trailing stacked axis (the layout built by ``prep_for_spharm``).
    harmonics : int
        Number of spectral coefficients per field; used only to declare the
        output chunk size for dask inputs.

    Returns
    -------
    np.ndarray or dask.array.Array
        Spectral coefficients with the harmonic axis first (lazy for dask
        input).
    """
    if not isinstance(arr, darray.core.Array):
        return st.grdtospec(arr)
    # map_blocks drops the two grid axes and adds one harmonic axis; this
    # assumes each block already spans full lat-lon shells.
    if arr.ndim == 3:
        chunks = ((harmonics,), arr.chunks[-1])
    else:
        chunks = ((harmonics,),)
    # ``np.complex`` was merely an alias of the builtin ``complex`` (i.e.
    # complex128) and was removed in NumPy 1.24; spell the dtype out.
    return darray.map_blocks(st.grdtospec, arr, chunks=chunks,
                             dtype=np.complex128, drop_axis=(0, 1),
                             new_axis=(0,))
def grdtospec(arr, prepped=False):
    """Transform data from grid space to spectral space.

    Assumes data are on a Gaussian grid.

    Parameters
    ----------
    arr : xr.DataArray
        Gridded data.
    prepped : bool, optional
        Set to True when ``arr`` has already been passed through
        ``prep_for_spharm``; otherwise it is prepared here.

    Returns
    -------
    xr.DataArray
        Spectral coefficients along ``_HARMONIC_DIM`` (plus
        ``_NON_HORIZONTAL_DIM`` when the prepared input has it).
    """
    if not prepped:
        arr, _ = prep_for_spharm(arr)
    st = create_spharmt(arr)
    harmonics = n_harmonics(arr)

    core_out = [_HARMONIC_DIM]
    sizes_out = {_HARMONIC_DIM: harmonics}
    if _NON_HORIZONTAL_DIM in arr.dims:
        core_out.append(_NON_HORIZONTAL_DIM)
        sizes_out[_NON_HORIZONTAL_DIM] = arr.sizes[_NON_HORIZONTAL_DIM]

    return xr.apply_ufunc(
        _grdtospec, st, arr, harmonics,
        input_core_dims=[[], arr.dims, []],
        output_core_dims=[core_out],
        output_sizes=sizes_out,
        exclude_dims=set(_HORIZONTAL_DIMS),
        dask='allowed',
    )
def _getu(st, vort, divg):
"""Wrap getuv to only return u"""
u, _ = st.getuv(vort, divg)
return u
def _getv(st, vort, divg):
"""Wrap getuv to only return v"""
_, v = st.getuv(vort, divg)
return v
def _getu_dask_allowed(st, vort, divg):
    """Return gridded u from spectral inputs, accepting dask arrays.

    Non-dask inputs go straight to ``_getu``. For dask inputs the leading
    harmonic axis of each block is replaced by a full (nlat, nlon) grid,
    keeping the chunking of an optional trailing stacked axis.
    """
    if not isinstance(vort, darray.core.Array):
        return _getu(st, vort, divg)
    out_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        out_chunks += (vort.chunks[-1], )
    return darray.map_blocks(_getu, st, vort, divg, chunks=out_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def _getv_dask_allowed(st, vort, divg):
    """Return gridded v from spectral inputs, accepting dask arrays.

    Non-dask inputs go straight to ``_getv``. For dask inputs the leading
    harmonic axis of each block is replaced by a full (nlat, nlon) grid,
    keeping the chunking of an optional trailing stacked axis.
    """
    if not isinstance(vort, darray.core.Array):
        return _getv(st, vort, divg)
    out_chunks = ((st.nlat, ), (st.nlon, ))
    if vort.ndim == 2:
        out_chunks += (vort.chunks[-1], )
    return darray.map_blocks(_getv, st, vort, divg, chunks=out_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def getuv(vort_grid, divg_grid):
    """Compute u and v from *gridded* vorticity and divergence.

    Taking gridded (rather than spectral) inputs makes it easy to reattach
    the original latitude, longitude and other coordinates to the results.

    Parameters
    ----------
    vort_grid, divg_grid : xr.DataArray
        Gridded vorticity and divergence on the same Gaussian grid.

    Returns
    -------
    u, v : xr.DataArray
        Gridded wind components carrying the coordinates of ``vort_grid``.
    """
    st = create_spharmt(vort_grid)
    orig_coords = vort_grid.coords

    vort_grid, _ = prep_for_spharm(vort_grid)
    divg_grid, flipped = prep_for_spharm(divg_grid)
    vort_spec = grdtospec(vort_grid, prepped=True)
    divg_spec = grdtospec(divg_grid, prepped=True)

    def _spec_to_grid(wrapper):
        # The harmonic dim is consumed; lat/lon come back as new dims laid
        # out like the prepared divergence.
        return xr.apply_ufunc(
            wrapper, st, vort_spec, divg_spec,
            input_core_dims=[[], vort_spec.dims, divg_spec.dims],
            output_core_dims=[divg_grid.dims],
            output_sizes=divg_grid.sizes,
            dask='allowed',
        )

    def _restore_layout(result):
        # Undo the preparation steps, then reattach the input coordinates.
        if _NON_HORIZONTAL_DIM in result.dims:
            result = result.unstack(_NON_HORIZONTAL_DIM)
        if flipped:
            result = flip_lat(result)
        # NOTE(review): overwriting coordinate values here assumes unstack
        # restores every dim in its original order -- confirm for the
        # xarray version in use.
        for name in orig_coords:
            result[name] = orig_coords[name]
        return result

    u = _spec_to_grid(_getu_dask_allowed)
    v = _spec_to_grid(_getv_dask_allowed)
    return _restore_layout(u), _restore_layout(v)
def _getvrt(st, u, v):
"""Return the spectral version of the vorticty given gridded u, v"""
vort, _ = st.getvrtdivspec(u, v)
return vort
def _getdiv(st, u, v):
"""Return the spectral version of the divergence given gridded u, v"""
_, div = st.getvrtdivspec(u, v)
return div
def _getvrt_dask_allowed(st, u, v, harmonics):
    """Return spectral vorticity from gridded u and v, accepting dask arrays.

    Parameters
    ----------
    st : Spharmt
        Transform object whose grid matches ``u`` and ``v``.
    u, v : np.ndarray or dask.array.Array
        Gridded winds with lat and lon as the first two axes and, optionally,
        one trailing stacked axis.
    harmonics : int
        Number of spectral coefficients; used only to declare the output
        chunk size for dask inputs.

    Returns
    -------
    np.ndarray or dask.array.Array
        Spectral vorticity with the harmonic axis first.
    """
    if not isinstance(u, darray.core.Array):
        return _getvrt(st, u, v)
    if u.ndim == 3:
        chunks = ((harmonics,), u.chunks[-1])
    else:
        chunks = ((harmonics,),)
    # ``np.complex`` (an alias of builtin ``complex``) was removed in
    # NumPy 1.24; ``np.complex128`` is the same dtype.
    return darray.map_blocks(_getvrt, st, u, v, chunks=chunks,
                             dtype=np.complex128, drop_axis=(0, 1),
                             new_axis=(0,))
def _getdiv_dask_allowed(st, u, v, harmonics):
    """Return spectral divergence from gridded u and v, accepting dask arrays.

    Parameters
    ----------
    st : Spharmt
        Transform object whose grid matches ``u`` and ``v``.
    u, v : np.ndarray or dask.array.Array
        Gridded winds with lat and lon as the first two axes and, optionally,
        one trailing stacked axis.
    harmonics : int
        Number of spectral coefficients; used only to declare the output
        chunk size for dask inputs.

    Returns
    -------
    np.ndarray or dask.array.Array
        Spectral divergence with the harmonic axis first.
    """
    if not isinstance(u, darray.core.Array):
        return _getdiv(st, u, v)
    if u.ndim == 3:
        chunks = ((harmonics,), u.chunks[-1])
    else:
        chunks = ((harmonics,),)
    # ``np.complex`` (an alias of builtin ``complex``) was removed in
    # NumPy 1.24; ``np.complex128`` is the same dtype.
    return darray.map_blocks(_getdiv, st, u, v, chunks=chunks,
                             dtype=np.complex128, drop_axis=(0, 1),
                             new_axis=(0,))
def getvrtdivspec(u_grid, v_grid):
    """Compute spectral vorticity and divergence from gridded u and v.

    Parameters
    ----------
    u_grid, v_grid : xr.DataArray
        Gridded zonal and meridional wind on the same Gaussian grid.

    Returns
    -------
    vort_spec, divg_spec : xr.DataArray
        Spectral coefficients along ``_HARMONIC_DIM``; any non-horizontal
        dims of the inputs remain stacked on ``_NON_HORIZONTAL_DIM``.
    """
    st = create_spharmt(u_grid)
    u_grid, _ = prep_for_spharm(u_grid)
    v_grid, _ = prep_for_spharm(v_grid)
    harmonics = n_harmonics(u_grid)

    core_out = [_HARMONIC_DIM]
    sizes_out = {_HARMONIC_DIM: harmonics}
    if _NON_HORIZONTAL_DIM in u_grid.dims:
        core_out.append(_NON_HORIZONTAL_DIM)
        sizes_out[_NON_HORIZONTAL_DIM] = u_grid.sizes[_NON_HORIZONTAL_DIM]

    def _apply(wrapper):
        # lat/lon are consumed and replaced by the harmonic dim.
        return xr.apply_ufunc(
            wrapper, st, u_grid, v_grid, harmonics,
            input_core_dims=[[], u_grid.dims, v_grid.dims, []],
            output_core_dims=[core_out],
            exclude_dims=set(_HORIZONTAL_DIMS),
            output_sizes=sizes_out,
            dask='allowed',
        )

    vort_spec = _apply(_getvrt_dask_allowed)
    divg_spec = _apply(_getdiv_dask_allowed)
    return vort_spec, divg_spec
def _spectogrd(st, arr):
    """Transform spectral coefficients back to grid space.

    Non-dask inputs go straight to ``st.spectogrd``. For dask inputs the
    leading harmonic axis of each block is replaced by a full (nlat, nlon)
    grid, keeping the chunking of an optional trailing stacked axis
    (assumes the harmonic axis is a single chunk -- TODO confirm at callers).
    """
    if not isinstance(arr, darray.core.Array):
        return st.spectogrd(arr)
    grid_chunks = ((st.nlat, ), (st.nlon, ))
    if arr.ndim == 2:
        grid_chunks = grid_chunks + (arr.chunks[-1], )
    return darray.map_blocks(st.spectogrd, arr, chunks=grid_chunks,
                             drop_axis=(0, ), new_axis=(0, 1))
def getvrtdivgrid(u_grid, v_grid):
    """Compute *gridded* vorticity and divergence from gridded u and v.

    Parameters
    ----------
    u_grid, v_grid : xr.DataArray
        Gridded zonal and meridional wind on the same Gaussian grid.

    Returns
    -------
    vort_grid, divg_grid : xr.DataArray
        Gridded vorticity and divergence with the lat/lon coordinates of
        ``u_grid`` and any non-horizontal dims unstacked again.
    """
    st = create_spharmt(u_grid)
    orig_coords = u_grid.coords
    vort_spec, divg_spec = getvrtdivspec(u_grid, v_grid)
    # Prepare v only to learn the prepared dim layout and whether latitude
    # was flipped, so both can be undone on the gridded outputs.
    v_grid, flipped = prep_for_spharm(v_grid)
    common_kwargs = {'input_core_dims': [[], vort_spec.dims],
                     'output_core_dims': [v_grid.dims],
                     # Sizes must describe the *prepared* layout named in
                     # output_core_dims (as in getuv); the unprepared u_grid
                     # lacks the stacked dim and has the original ones.
                     'output_sizes': v_grid.sizes,
                     'dask': 'allowed'}
    vort_grid = xr.apply_ufunc(_spectogrd, st, vort_spec, **common_kwargs)
    divg_grid = xr.apply_ufunc(_spectogrd, st, divg_spec, **common_kwargs)
    if flipped:
        vort_grid = flip_lat(vort_grid)
        divg_grid = flip_lat(divg_grid)
    # lat/lon are new output dims created by apply_ufunc, so copy the
    # original coordinate values onto them.
    for coord in _HORIZONTAL_DIMS:
        vort_grid[coord] = orig_coords[coord]
        divg_grid[coord] = orig_coords[coord]
    if _NON_HORIZONTAL_DIM in vort_grid.dims:
        return (vort_grid.unstack(_NON_HORIZONTAL_DIM),
                divg_grid.unstack(_NON_HORIZONTAL_DIM))
    return vort_grid, divg_grid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment