Skip to content

Instantly share code, notes, and snippets.

@pratikac
Last active October 6, 2021 18:01
Show Gist options
  • Save pratikac/68d6d94e4739786798e90691fb1a581b to your computer and use it in GitHub Desktop.
Save pratikac/68d6d94e4739786798e90691fb1a581b to your computer and use it in GitHub Desktop.
class View(nn.Module):
    """Reshape the input to ``(-1, o)``.

    Every ``o`` consecutive elements become one row, so e.g. a
    ``(N, o, 1, 1)`` tensor becomes ``(N, o)``. Any extra elements (for
    example spatial positions beyond 1x1) are folded into the leading
    dimension rather than raising, because of the ``-1``.

    Args:
        o: size of the trailing (output) dimension.
    """

    def __init__(self, o):
        super().__init__()
        self.o = o

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would succeed and falls
        # back to a copy for non-contiguous inputs instead of raising.
        return x.reshape(-1, self.o)
class allcnn_t(nn.Module):
    """All-convolutional image classifier (appears to follow the All-CNN-C layout).

    Structure of ``self.m``:
    input dropout (p=0.2) -> three 3x3 ``convbn`` units at width ``c1`` (the
    last with stride 2) -> dropout -> three 3x3 ``convbn`` units at width
    ``c2`` (the last with stride 2) -> dropout -> two 3x3 ``convbn`` units at
    width ``c2`` -> a 1x1 ``convbn`` to ``num_classes`` channels ->
    ``AvgPool2d(8)`` -> ``View(num_classes)``.

    Args:
        c1: channel width of the first stage.
        c2: channel width of the second and third stages.
        num_classes: number of output channels / classes (default 10).
        dropout: dropout probability after each downsampling stage
            (default 0.5).

    Note:
        The first convolution expects 3 input channels. The two stride-2
        convolutions halve H and W twice, so ``AvgPool2d(8)`` reduces the map
        to 1x1 only for 32x32 inputs. For other sizes the leftover spatial
        positions are folded by ``View`` into the leading (batch) dimension.
        Constructing the model prints its parameter count.
    """

    def __init__(self, c1=96, c2=192, num_classes=10, dropout=0.5):
        super().__init__()
        d = dropout

        def convbn(ci, co, ksz, s=1, pz=0):
            # Conv -> ReLU -> BatchNorm: in this model normalization is
            # applied after the activation.
            return nn.Sequential(
                nn.Conv2d(ci, co, ksz, stride=s, padding=pz),
                nn.ReLU(True),
                nn.BatchNorm2d(co))

        self.m = nn.Sequential(
            nn.Dropout(0.2),                 # dropout on the input itself
            convbn(3, c1, 3, 1, 1),
            convbn(c1, c1, 3, 1, 1),
            convbn(c1, c1, 3, 2, 1),         # stride 2: halves H and W
            nn.Dropout(d),
            convbn(c1, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 2, 1),         # stride 2: halves H and W again
            nn.Dropout(d),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, num_classes, 1, 1),   # 1x1 conv to one channel per class
            nn.AvgPool2d(8),
            View(num_classes))
        print('Num parameters: ', sum(p.numel() for p in self.m.parameters()))

    def forward(self, x):
        """Return unnormalized class scores (no softmax), reshaped to ``(-1, num_classes)``."""
        return self.m(x)
import os, sys, subprocess, json, argparse
from itertools import product
import torch as th
# Command-line interface.
#   -c/--command : base command line that each generated job extends.
#   -p/--params  : JSON object mapping option names to the values to sweep.
#                  Required, because the script later parses it unconditionally
#                  with json.loads (which fails on None).
#   -r/--run     : actually launch the jobs; otherwise they are only printed.
#   -j/--max_jobs: how many jobs may run concurrently.
#   --dist       : do not append a per-job "-g <gpu>" option.
parser = argparse.ArgumentParser(description='Quick dirty hyperoptim')
parser.add_argument('-c', '--command', help='Main command', type=str, required=True)
parser.add_argument('-p', '--params', help='JSON dict of the hyper-parameters', type=str,
                    required=True)
parser.add_argument('-r', '--run', help='run', action='store_true')
parser.add_argument('-j', '--max_jobs', help='max jobs', type=int, default=1)
parser.add_argument('--dist', help='using dist sgd', action='store_true')
opt = vars(parser.parse_args())
def chunks(l, n):
    """Yield successive slices of ``l`` of length ``n``; the last may be shorter.

    Args:
        l: a sliceable sequence (list, tuple, str, ...).
        n: chunk size; must be a positive integer.

    Raises:
        ValueError: if ``n`` < 1 (raised when iteration starts, since this is
            a generator).
    """
    if n < 1:
        # A zero step would make ``range`` raise a cryptic error, and a
        # negative one would silently yield nothing.
        raise ValueError('chunk size must be a positive integer, got %r' % (n,))
    # ``range`` rather than Python-2-only ``xrange`` so this runs on Python 3.
    for start in range(0, len(l), n):
        yield l[start:start + n]
def run_cmds(cmds, max_cmds):
    """Run shell commands in batches of at most ``max_cmds`` concurrent jobs.

    Each command in a batch is started with ``subprocess.Popen(c, shell=True)``.
    The function waits for every job in the batch before starting the next
    batch. Exit statuses of the jobs are not checked.

    On ``KeyboardInterrupt`` every still-running process of the current batch
    is killed and the interpreter exits via ``sys.exit()``.

    Args:
        cmds: list of shell command strings.
        max_cmds: batch size, i.e. maximum number of commands running at once.
    """
    for batch in chunks(cmds, max_cmds):
        procs = []
        try:
            for c in batch:
                procs.append(subprocess.Popen(c, shell=True))
            for p in procs:
                p.wait()
        except KeyboardInterrupt:
            print('Killing everything')
            for p in procs:
                # Only signal processes that are still alive. ``poll()`` reaps
                # finished ones, so we never signal a reused pid.
                # NOTE(review): with shell=True this kills the spawned shell;
                # whether its child processes also exit is not controlled here.
                if p.poll() is None:
                    p.kill()
            sys.exit()
# Build one command line per point of the hyper-parameter grid (the Cartesian
# product of the value lists given in --params), then print or run them.
cmd = opt['command']
params = json.loads(opt['params'])

# A scalar value is treated as a one-element list, so {"lr": 0.1} is accepted.
# A bare string would otherwise be iterated character by character, and a
# number would make ``product`` raise. An empty dict yields a single command.
keys = list(params)
values = [v if isinstance(v, list) else [v] for v in params.values()]

# GPU ids handed out round-robin to the jobs (not used with --dist).
gs = list(range(th.cuda.device_count()))
if not opt['dist'] and not gs:
    sys.exit('No CUDA devices visible: cannot assign -g to jobs (use --dist to omit it)')

cmds = []
for combo in product(*values):
    flags = []
    for k, v in zip(keys, combo):
        # Single-letter keys become short options (-k), longer ones long options (--key).
        dash = '--' if len(k) > 1 else '-'
        flags.append(' %s%s %s' % (dash, k, v))
    # ' -l' is always appended; its meaning is defined by the target command.
    c = cmd + ''.join(flags) + ' -l'
    if not opt['dist']:
        c += ' -g %d' % gs[len(cmds) % len(gs)]
    cmds.append(c)

if not opt['run']:
    for c in cmds:
        print(c)
else:
    run_cmds(cmds, opt['max_jobs'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment