Last active: February 24, 2020 23:52
-
-
Save jakirkham/8a95d6a04c75342d8b89e82fe130407d to your computer and use it in GitHub Desktop.
Attempt at repro for multi-array return
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys
import dask
import dask.array
from dask.delayed import delayed
import distributed

# Cluster backend: dask_cuda's GPU cluster talking UCX when dask_cuda can be
# imported, otherwise a plain distributed LocalCluster talking TCP.
try:
    from dask_cuda import LocalCUDACluster as DaskCluster
except ImportError:
    from distributed import LocalCluster as DaskCluster

    client_kwargs = {"protocol": "tcp"}
else:
    client_kwargs = {"protocol": "ucx"}

# Array backend: CuPy when it can be imported, NumPy otherwise.  Aliased as
# ``xnumpy`` so the rest of the script does not care which one was picked.
try:
    import cupy as xnumpy
except ImportError:
    import numpy as xnumpy
@delayed
def double_halve(x):
    """Return the pair ``(2 * x, x / 2)`` as a single delayed task.

    ``@delayed`` is applied without ``nout``, so one call produces one
    ``Delayed`` object whose computed value is the whole 2-tuple.
    """
    doubled = 2 * x
    halved = x / 2
    return doubled, halved
def get_client():
    """Start a local Dask cluster and return a ``distributed.Client`` for it.

    The transport settings in ``client_kwargs`` (chosen at import time) are
    passed to the cluster constructor rather than to ``Client``.  When
    ``Client`` is given an existing cluster object, it only uses extra
    keyword arguments for a LocalCluster it would create itself.  Passing
    ``protocol`` to ``Client`` here would therefore be silently ignored.

    Returns
    -------
    distributed.Client
        A client connected to the newly started cluster.
    """
    cluster = DaskCluster(**client_kwargs)
    try:
        return distributed.Client(cluster)
    except BaseException:
        # Don't leave the scheduler/workers running if the client fails.
        cluster.close()
        raise
def main(*argv):
    """Run the multi-value-return reproducer and print the gathered results.

    Builds a 20-element random Dask array split into four 5-element chunks
    (backed by ``xnumpy``'s ``RandomState``) and applies ``double_halve`` to
    every chunk.  It then computes all resulting delayed objects on the
    cluster and prints the gathered list between two blank lines.

    Parameters
    ----------
    *argv : str
        Command-line arguments.  Accepted to fit the ``main(*sys.argv)``
        call convention, but currently unused.

    Returns
    -------
    int
        Process exit status, ``0`` on success.
    """
    client = get_client()
    # Grab the cluster up front: a Client does not own a cluster object that
    # was handed to it, so the cluster must be closed separately.
    cluster = getattr(client, "cluster", None)
    try:
        rs = dask.array.random.RandomState(
            RandomState=xnumpy.random.RandomState
        )
        a = rs.random((20,), chunks=(5,))
        # One delayed call per chunk; each resolves to a (2*x, x/2) tuple.
        tasks = [double_halve(block) for block in a.blocks]
        futures = client.compute(tasks)
        results = client.gather(futures)
        print("")
        print(results)
        print("")
    finally:
        # Shut the client and cluster down explicitly instead of relying
        # on interpreter teardown.
        client.close()
        if cluster is not None:
            cluster.close()
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main(*sys.argv)
    sys.exit(exit_code)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.