Skip to content

Instantly share code, notes, and snippets.

@selipot
Created September 22, 2023 14:34
Show Gist options
  • Save selipot/93c061f7715b95ce3ec3bcf169c1165a to your computer and use it in GitHub Desktop.
Save selipot/93c061f7715b95ce3ec3bcf169c1165a to your computer and use it in GitHub Desktop.
Using clouddrift.wavelet functions with xarray objects
# %%
import clouddrift
from clouddrift.wavelet import morse_wavelet, morse_wavelet_transform, wavelet_transform
import xarray as xr
import numpy as np
# %% Synthetic complex-valued test field: 1440 time steps on a 10 x 20 (x, y) grid.
# Dimension names are passed explicitly through `dims` rather than being inferred
# from the key order of the `coords` dictionary.
dat = xr.DataArray(
    data=np.random.random((1440, 10, 20)) + 1j * np.random.random((1440, 10, 20)),
    dims=("time", "x", "y"),
    coords={"time": np.arange(1440), "x": np.arange(10), "y": np.arange(20)},
)
# %% Reference result: apply the transform directly to the underlying numpy array.
# Morse wavelet parameters (gamma, beta) and a single analysis frequency of
# 2*pi*0.2 rad per time step (0.2 cycles per sample, assuming unit spacing).
gamma = 3
beta = 4
radian_frequency = 2 * np.pi * np.array([0.2])
# With complex=True the function returns a pair of transforms, unpacked into
# wtp / wtn. The time axis is looked up by name instead of hard-coding 0, so
# this stays correct if the dimension order of `dat` changes.
wtp, wtn = morse_wavelet_transform(
    dat.to_numpy(),
    gamma,
    beta,
    radian_frequency,
    complex=True,
    time_axis=dat.get_axis_num("time"),
)
# %% It also works through xr.apply_ufunc.
# apply_ufunc cannot turn the tuple returned in the complex=True case into a
# DataArray, so the function is called twice with complex=False: once on
# 0.5 * dat and once on its complex conjugate. The factor 0.5 reproduces what
# morse_wavelet_transform applies to complex input (bandpass normalization).
# apply_ufunc moves the input core dimension ("time") to the last axis before
# calling the function, hence time_axis=-1 regardless of where "time" sits in
# `dat`; the outputs also carry "time" as their last dimension.
ufunc_options = dict(
    input_core_dims=[["time"], [], [], []],
    output_core_dims=[["time"]],
    kwargs={"time_axis": -1, "complex": False},
)
wtp2 = xr.apply_ufunc(
    morse_wavelet_transform, 0.5 * dat, gamma, beta, radian_frequency, **ufunc_options
)
wtn2 = xr.apply_ufunc(
    morse_wavelet_transform, 0.5 * dat.conj(), gamma, beta, radian_frequency, **ufunc_options
)
# Restore the dimension order of `dat` before comparing with the numpy reference.
print(np.allclose(wtp, wtp2.transpose(*dat.dims).to_numpy()))
print(np.allclose(wtn, wtn2.transpose(*dat.dims).to_numpy()))
# %% morse_wavelet_transform recalculates the wavelet on every call, so instead
# build it once with morse_wavelet and apply it with wavelet_transform.
# The wavelet length is taken from the data rather than hard-coded.
wavelet, _ = morse_wavelet(dat.sizes["time"], gamma, beta, radian_frequency)
# The precomputed wavelet is bound inside the applied function (default-arg
# lambda) instead of being passed to apply_ufunc as an extra positional argument.
# Without dask both forms behave the same. With dask="parallelized", a positional
# numpy argument is handed to dask as an additional array input, whose shape dask
# then tries to broadcast against the data's loop dimensions. Binding it here
# keeps this call usable with dask-backed data too.
wtp3 = xr.apply_ufunc(
    lambda x, w=wavelet: wavelet_transform(x, w, time_axis=-1),
    0.5 * dat,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
)
print(np.allclose(wtp, wtp3.transpose(*dat.dims).to_numpy()))
# %% now let's try with dask arrays
import dask
# %% Create a larger synthetic DataArray (100 x 200 grid) backed by dask arrays.
# Dimension names are given explicitly instead of being inferred from `coords`.
# Chunks are specified per dimension name: dimension-order tuples are deprecated
# in recent xarray. "time" is kept in a single chunk (-1) because
# apply_ufunc(dask="parallelized") requires core dimensions to be unchunked.
dat = xr.DataArray(
    data=np.random.random((1440, 100, 200)) + 1j * np.random.random((1440, 100, 200)),
    dims=("time", "x", "y"),
    coords={"time": np.arange(1440), "x": np.arange(100), "y": np.arange(200)},
).chunk({"time": -1, "x": 20, "y": 20})
# %% Numpy reference for the dask-backed field (to_numpy() loads it into memory).
wtp, wtn = morse_wavelet_transform(
    dat.to_numpy(),
    gamma,
    beta,
    radian_frequency,
    complex=True,
    time_axis=dat.get_axis_num("time"),
)
# Same two-call apply_ufunc recipe as the numpy-backed case: complex=False on
# 0.5 * dat and on its conjugate, with "time" moved last by apply_ufunc (hence
# time_axis=-1). dask="parallelized" makes xarray apply the function block by
# block over the dask chunks, producing lazy results.
dask_ufunc_options = dict(
    input_core_dims=[["time"], [], [], []],
    output_core_dims=[["time"]],
    kwargs={"time_axis": -1, "complex": False},
    dask="parallelized",
)
wtp4 = xr.apply_ufunc(
    morse_wavelet_transform, 0.5 * dat, gamma, beta, radian_frequency, **dask_ufunc_options
)
wtn4 = xr.apply_ufunc(
    morse_wavelet_transform, 0.5 * dat.conj(), gamma, beta, radian_frequency, **dask_ufunc_options
)
# %%
# to_numpy() triggers the lazy computation; restore the dimension order of `dat`
# before comparing with the numpy reference.
print(np.allclose(wtp, wtp4.transpose(*dat.dims).to_numpy()))
print(np.allclose(wtn, wtn4.transpose(*dat.dims).to_numpy()))
# %% Now with a precomputed wavelet + wavelet_transform on the dask-backed field.
# Passing `wavelet` to apply_ufunc as an extra positional argument (with core dims
# []) fails here. With dask="parallelized", xarray forwards that numpy array to
# dask's apply_gufunc as an additional input with no core dimensions. Dask then
# tries to broadcast its (..., time) shape against the chunked (x, y) loop
# dimensions of the data (likely the error seen in the first attempt).
# Binding the wavelet inside the applied function keeps it out of that
# broadcasting. output_dtypes is supplied so dask does not infer the dtype by
# calling the function on tiny dummy arrays, which may not be compatible with a
# full-length wavelet.
wavelet, _ = morse_wavelet(dat.sizes["time"], gamma, beta, radian_frequency)
wtp5 = xr.apply_ufunc(
    lambda x, w=wavelet: wavelet_transform(x, w, time_axis=-1),
    0.5 * dat,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
    # assumes the transform's dtype is the promotion of data and wavelet dtypes
    # (complex here) — TODO confirm against wavelet_transform
    output_dtypes=[np.result_type(dat.dtype, wavelet.dtype)],
).compute()
print(np.allclose(wtp, wtp5.transpose(*dat.dims).to_numpy()))
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment