Last active
May 17, 2023 04:47
-
-
Save russelljjarvis/ab7d76c4e5f061e038bd9487234abcf7 to your computer and use it in GitHub Desktop.
NMNIST_DATA_INTO_SNN.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using PyCall | |
using Revise | |
using Odesa | |
using Random | |
using ProgressMeter | |
using JLD | |
using NumPyArrays | |
using LoopVectorization | |
using Plots | |
"""
    build_data_set_native(events, storage, cnt, input_shape, l_change_cnt, l_old;
                          grid_dims=(35, 35))

Flatten an iterable of `(x, y, ts, p, l)` events into typed column vectors.

Each event is destructured into its five fields:
- `x` and `y` are stored as `Int32`.
- `ts` is divided by 1000 and stored as `Float32`.
- `p` is stored as `Int8`.
- `l` is stored as `Int32`.

A sixth column holds the column-major linear index of `(x, y)` in a `grid_dims` grid,
i.e. a per-pixel node id.

`storage` and `input_shape` are accepted for call-site compatibility but are not used.
`l_change_cnt` and `l_old` are passed through unchanged.

# Keywords
- `grid_dims`: size of the pixel grid used to compute node ids. It defaults to
  `(35, 35)`, the previously hard-coded size.

# Returns
`(cnt, columns, l_change_cnt, l_old)`, where:
- `cnt` is the input `cnt` plus the number of events consumed.
- `columns` is `(xx, yy, tts, polarity, label, pop_stimulation)`.

`pop_stimulation` is returned as a `Vector{Any}` so the tuple type matches data saved
by earlier versions.

# Throws
`BoundsError` if an `(x, y)` pair falls outside `grid_dims`. Coordinates are expected
to be 1-based.
"""
function build_data_set_native(events, storage, cnt, input_shape, l_change_cnt, l_old;
                               grid_dims::Tuple{Integer,Integer}=(35, 35))
    xx = Int32[]
    yy = Int32[]
    tts = Float32[]
    polarity = Int8[]
    label = Int32[]
    pop_stimulation = Int32[]
    # Index map over the grid shape only; no need to allocate a dense matrix for it.
    node_index = LinearIndices(grid_dims)
    for ev in events
        (x, y, ts, p, l) = ev
        xi = convert(Int32, x)
        yi = convert(Int32, y)
        push!(pop_stimulation, Int32(node_index[xi, yi]))
        push!(xx, xi)
        push!(yy, yi)
        # Keep the original Float32 -> Float64 division -> Float32 path so
        # timestamps stay bit-identical to previously saved data.
        push!(tts, Float32(convert(Float32, ts) / 1000.0))
        push!(polarity, convert(Int8, p))
        push!(label, convert(Int32, l))
        cnt += 1
    end
    did_it_exec = (xx, yy, tts, polarity, label, convert(Vector{Any}, pop_stimulation))
    return (cnt, did_it_exec, l_change_cnt, l_old)
end
"""
    bds!()

Convert NMNIST samples from the Python helper module `batch_nmnist_motions` into
Julia column tuples. The module is imported after putting the current directory
(`""`) first on Python's `sys.path`.

For every 200-item stride over the dataset:
1. The first two items of the stride (clamped to the dataset end) are requested.
2. The items are flattened with `build_data_set_native`.
3. The result is written to `part_mnmist_<cnt>.jld`, where `cnt` is the running event
   count.
4. The result is appended to `storage`.

After the loop, the full `storage` vector is written to `all_mnmist.jld`, which is the
file the top-level script loads.

Returns `storage`.
"""
function bds!()
    pushfirst!(PyVector(pyimport("sys")."path"), "")
    nmnist_module = pyimport("batch_nmnist_motions")
    dataset::PyObject = nmnist_module.NMNIST("./")
    training_order = 0:dataset.get_count()-1
    storage = Tuple{Vector{Int32},Vector{Int32},Vector{Float32},Vector{Int8},
                    Vector{Int32},Vector{Any}}[]
    input_shape = dataset.get_element_dimensions()
    cnt = 0
    l_change_cnt = 0
    l_old = 4
    n_items = length(training_order)
    @showprogress for batch in 1:200:n_items
        # NOTE(review): only two consecutive items per 200-item stride are sampled.
        # This is kept from the original, but confirm it is intentional.
        # Clamp so a final stride starting on the last item does not index past the end.
        stop = min(batch + 1, n_items)
        events = dataset.get_dataset_item(training_order[batch:stop])
        cnt, did_it_exec, l_change_cnt, l_old =
            build_data_set_native(events, storage, cnt, input_shape, l_change_cnt, l_old)
        push!(storage, did_it_exec)
        @save "part_mnmist_$cnt.jld" did_it_exec
    end
    # Persist the aggregate so `@load "all_mnmist.jld" storage` has something to read.
    @save "all_mnmist.jld" storage
    return storage
end
# Build the data set, then inspect the aggregated samples.
bds!()
# NOTE(review): assumes all_mnmist.jld exists and holds a `storage` vector of tuples
# laid out as (x, y, times, polarity, label, node id). Confirm before running.
@load "all_mnmist.jld" storage
# Columns of the first sample; `times` and `nodes` feed the raster plot below.
(x, y, times, p, l, nodes) = storage[1]
# Print the first distinct label of every sample. Iterate the samples directly:
# the previous `storage[s]` indexed the vector with the sample tuple itself and threw.
for sample in storage
    # Explicit `local` avoids the ambiguous soft-scope clash with the global `l` above.
    local l = sample[5]
    @show(unique(l)[1])
end
display(Plots.scatter(times, nodes, markersize=0.1))
""" | |
NMNIST_Motions dataset class | |
Provides access to the event-based motions NMNIST dataset. This is a version of the
NMNIST dataset in which we have separated out the events by motion and linked them
to the specified MNIST images. This allows us to retrieve any motion from the dataset
along with the associated MNIST image.
""" | |
from os.path import exists | |
#import h5py | |
import numpy as np | |
import tonic | |
from numpy.lib.recfunctions import merge_arrays | |
class NMNIST(object):
    """A dataset interface to the NMNIST dataset, backed by ``tonic.datasets.NMNIST``."""

    # Upper bound on how many samples get_dataset_item merges into one event stream.
    MAX_ITEMS_PER_BATCH = 100

    def __init__(self, dataset_path, train=True, first_saccade_only=False, transform=None):
        """Create a dataset instance rooted at ``dataset_path``.

        Args:
            dataset_path: existing directory. Tonic stores its files in the ``data``
                sub-directory (``./data`` for ``"./"``). Previously the path was
                validated but ignored.
            train: forwarded to ``tonic.datasets.NMNIST``.
            first_saccade_only: forwarded to ``tonic.datasets.NMNIST``.
            transform: forwarded to ``tonic.datasets.NMNIST``.

        Raises:
            FileNotFoundError: if ``dataset_path`` does not exist. This is a subclass
                of ``Exception``, so existing handlers still catch it.
        """
        import os.path

        super().__init__()
        if not exists(dataset_path):
            raise FileNotFoundError("Specified dataset path does not exist")
        self._dataset = tonic.datasets.NMNIST(
            save_to=os.path.join(dataset_path, "data"),
            train=train,
            first_saccade_only=first_saccade_only,
            transform=transform,
        )

    def get_dataset_item(self, indices):
        """Merge several samples into one time-ordered structured event array.

        Args:
            indices: sequence of at most ``MAX_ITEMS_PER_BATCH`` sample indices.

        Returns:
            A structured array holding every event of the requested samples. It has
            each sample's own fields plus an int64 ``label`` column. ``x`` and ``y``
            are shifted by +1, giving 1-based coordinates. Rows are stably sorted by
            ``t``, so events with equal timestamps keep their request order.

        Raises:
            AssertionError: if too many indices are given. This is raised explicitly,
                rather than via ``assert``, so the check also runs under ``python -O``.
            ValueError: if ``indices`` is empty.
        """
        if len(indices) > self.MAX_ITEMS_PER_BATCH:
            raise AssertionError(
                "at most %d indices allowed, got %d"
                % (self.MAX_ITEMS_PER_BATCH, len(indices))
            )
        if len(indices) == 0:
            raise ValueError("indices must contain at least one sample index")
        all_events = []
        for index in indices:
            events, label = self._dataset[index]
            label_array = np.empty(events["x"].shape[0], dtype=[("label", "i8")])
            label_array["label"] = label
            event_array = merge_arrays((events, label_array), flatten=True)
            event_array["x"] = event_array["x"] + 1
            event_array["y"] = event_array["y"] + 1
            all_events.append(event_array)
        super_events = np.hstack(all_events)
        return super_events[np.argsort(super_events["t"], kind="stable")]

    def get_count(self):
        """Return the number of samples in the dataset."""
        return len(self._dataset)

    def get_element_dimensions(self):
        """Return the sensor size reported by the underlying tonic dataset."""
        return self._dataset.sensor_size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment