Last active: May 4, 2016 18:43
-
-
Save cmey/e45ed9dc2b04bc216dcf032d50aeae99 to your computer and use it in GitHub Desktop.
Reikna FFT produces unexpected transposed result when either input or output is Fortran-ordered
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reproduction script: compare Reikna's FFT with numpy.fft.fftn for every
# combination of C- and Fortran-ordered input/output buffers. The reported
# symptom is a transposed result whenever either buffer is Fortran-ordered.
import itertools

import numpy as np
from numpy.fft import fftn as npfftn
from numpy.testing import assert_array_almost_equal
import pyopencl as cl
import pyopencl.array  # makes ``cl.array`` (used below) available
from reikna import cluda
from reikna.fft import FFT as reikna_FFT

# Select device: index into cl.get_platforms(), then into that platform's GPUs.
platform = 0
device = 0

# Init OpenCL.
cl_platforms = cl.get_platforms()
cl_devices = cl_platforms[platform].get_devices(cl.device_type.GPU)
cl_device = cl_devices[device]
cl_context = cl.Context(devices=[cl_device])
cl_queue = cl.CommandQueue(
    cl_context, properties=cl.command_queue_properties.PROFILING_ENABLE
)

# The Reikna thread only wraps the queue and does not depend on the loop
# variables, so create it once rather than once per layout combination.
reikna_thread = cluda.ocl_api().Thread(cl_queue)

# Input initialized with seeded random data, in both memory layouts.
data_shape = (2, 2)
np.random.seed(1)
np_array_C = np.random.random(data_shape).astype(dtype=np.complex64)
np_array_F = np.asfortranarray(np_array_C)
# Output initialized with zeros; zeros_like (order='K' by default) keeps the
# layout of the array it is based on.
np_result_C = np.zeros_like(np_array_C, dtype=np.complex64)
np_result_F = np.zeros_like(np_array_F, dtype=np.complex64)

axes = (1, 0)

# Tag every buffer with its layout so a mismatch report names the case.
inputs = [("C", np_array_C), ("F", np_array_F)]
outputs = [("C", np_result_C), ("F", np_result_F)]

failures = []
for (in_order, np_in), (out_order, np_out) in itertools.product(inputs, outputs):
    case = "input {}-ordered, output {}-ordered".format(in_order, out_order)
    # Reset input and output data on device.
    cl_data_in = cl.array.to_device(cl_queue, np_in)
    cl_data_out = cl.array.to_device(cl_queue, np_out)
    # Prepare FFT. NOTE(review): the computation is built from the input
    # array alone; the output buffer passed below may have a different
    # layout, which is exactly the combination this script exercises.
    prepared = reikna_FFT(cl_data_in, axes=axes).compile(
        reikna_thread, fast_math=True
    )
    # Call FFT (forward transform), writing into cl_data_out.
    prepared(cl_data_out, cl_data_in, inverse=False)
    npfft_to_compare = npfftn(np_in, axes=axes)
    try:
        assert_array_almost_equal(cl_data_out.get(), npfft_to_compare, decimal=1)
    except AssertionError as err:
        failures.append("{}:\n{}".format(case, err))
        print("MISMATCH: " + case)
    else:
        print("ok: " + case)

# Fail only after every combination has run, so the report lists all failing
# layouts instead of stopping at the first one.
if failures:
    raise AssertionError("\n\n".join(failures))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.