Last active: October 6, 2021 18:01
Save pratikac/68d6d94e4739786798e90691fb1a581b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
class View(nn.Module):
    """Flatten the input into rows of width ``o``.

    ``forward(x)`` returns ``x.view(-1, o)``: the leading dimension is
    inferred, so ``x.numel()`` must be divisible by ``o`` and ``x`` must be
    viewable without a copy (as required by ``Tensor.view``).
    """

    def __init__(self, o):
        super().__init__()
        # Size of the trailing dimension of every output.
        self.o = o

    def forward(self, x):
        target_shape = (-1, self.o)
        return x.view(*target_shape)
class allcnn_t(nn.Module):
    """All-CNN-style image classifier built from conv -> ReLU -> BatchNorm stages.

    Args:
        c1: channel width of the first three convolutions.
        c2: channel width of the remaining 3x3 convolutions.
        num_classes: number of output scores per image (default 10, the
            value previously hard-coded).
        dropout: dropout probability applied after each stride-2 stage
            (default 0.5, the value previously hard-coded). The dropout on
            the raw input stays fixed at 0.2.

    The input must have 3 channels. Two stride-2 convolutions reduce the
    spatial size, e.g. 32x32 -> 8x8. A global average pool then collapses
    whatever spatial size remains, so ``forward`` returns a
    ``(batch, num_classes)`` tensor.

    Constructing the model prints its parameter count.
    """

    def __init__(self, c1=96, c2=192, num_classes=10, dropout=0.5):
        super().__init__()
        d = dropout

        def convbn(ci, co, ksz, s=1, pz=0):
            # Conv followed by ReLU and *then* BatchNorm (this order is
            # intentional in the original and kept as-is).
            return nn.Sequential(
                nn.Conv2d(ci, co, ksz, stride=s, padding=pz),
                nn.ReLU(True),
                nn.BatchNorm2d(co))

        self.m = nn.Sequential(
            nn.Dropout(0.2),
            convbn(3, c1, 3, 1, 1),
            convbn(c1, c1, 3, 1, 1),
            convbn(c1, c1, 3, 2, 1),   # stride 2: halves spatial size
            nn.Dropout(d),
            convbn(c1, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 2, 1),   # stride 2: halves spatial size again
            nn.Dropout(d),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            # 1x1 conv to one map per class; note it also goes through
            # ReLU + BatchNorm like every other stage.
            convbn(c2, num_classes, 1, 1),
            # Global average pool. For 32x32 inputs the map here is 8x8, so
            # this equals the former AvgPool2d(8). It also yields 1x1 for
            # other input sizes instead of erroring or leaking spatial
            # cells into the batch dimension.
            nn.AdaptiveAvgPool2d(1),
            View(num_classes))
        print('Num parameters: ', sum(p.numel() for p in self.m.parameters()))

    def forward(self, x):
        return self.m(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import argparse
import json
import os
import shlex
import subprocess
import sys
from itertools import product

import torch as th
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError('must be >= 1, got %d' % value)
    return value


# Command-line interface. The parsed options are stored in the module-level
# dict ``opt``.
parser = argparse.ArgumentParser(description='Quick dirty hyperoptim')
# Base command; one flag per hyper-parameter is appended to it.
parser.add_argument('-c', '--command', help='Main command', type=str, required=True)
# The grid is mandatory: without it the script would call json.loads(None).
parser.add_argument('-p', '--params', help='JSON dict of the hyper-parameters', type=str,
                    required=True)
# Without -r the generated commands are only printed, not executed.
parser.add_argument('-r', '--run', help='run', action='store_true')
# Number of commands run concurrently; 0 or negative would break batching,
# so it is rejected at parse time.
parser.add_argument('-j', '--max_jobs', help='max jobs', type=_positive_int, default=1)
# With --dist, no '-g <gpu>' flag is appended to the generated commands.
parser.add_argument('--dist', help='using dist sgd', action='store_true')
opt = vars(parser.parse_args())
def chunks(l, n):
    """Yield consecutive slices of ``l`` of length ``n``; the last may be shorter.

    Args:
        l: a sliceable sequence (list, str, ...).
        n: slice length; must be >= 1.

    Raises:
        ValueError: when iteration starts, if ``n`` is less than 1. A zero
            step would otherwise raise an opaque range error, and a negative
            step would silently yield nothing.
    """
    if n < 1:
        raise ValueError('chunk size must be >= 1, got %r' % (n,))
    # range rather than the Python 2-only xrange, so this runs on Python 3.
    for i in range(0, len(l), n):
        yield l[i:i + n]
def run_cmds(cmds, max_cmds):
    """Run shell commands in batches of at most ``max_cmds`` concurrent processes.

    Every command of a batch is started, then all of them are waited on; the
    next batch starts only after the current one has fully finished.

    On Ctrl-C (KeyboardInterrupt) the processes of the current batch are
    killed and reaped, and the script exits with status 1.

    Args:
        cmds: list of command strings, each executed with ``shell=True``.
        max_cmds: batch size; must be >= 1 (``chunks`` raises ValueError
            otherwise).
    """
    for batch in chunks(cmds, max_cmds):
        procs = []
        try:
            for c in batch:
                procs.append(subprocess.Popen(c, shell=True))
            for p in procs:
                p.wait()
        except KeyboardInterrupt:
            print('Killing everything')
            # NOTE: with shell=True, kill() signals the shell process; programs
            # it spawned may not receive the signal directly.
            for p in procs:
                p.kill()  # no-op for processes that already exited
            # Reap the killed children so they do not linger as zombies.
            for p in procs:
                p.wait()
            # Non-zero status so a calling script does not mistake an
            # interrupted sweep for a successful one.
            sys.exit(1)
# Expand the JSON hyper-parameter grid into one shell command per combination,
# then either print the commands (default) or run them (-r / --run).
cmd = opt['command']
# A missing -p arrives as None; treat it as an empty grid instead of crashing
# in json.loads(None).
params = json.loads(opt['params']) if opt['params'] else {}
cmds = []
gs = range(th.cuda.device_count())
if not opt['dist'] and not gs:
    # Round-robin GPU assignment below needs at least one device; without
    # this check the modulo by len(gs) raises ZeroDivisionError.
    sys.exit('No CUDA devices found; cannot assign -g to each run '
             '(pass --dist to skip it).')

# Each grid value is normally a JSON list of settings to sweep. A scalar is
# wrapped as a one-element list. Otherwise product() fails on a number, or
# silently iterates a string one character at a time. An empty grid yields
# exactly one command with no hyper-parameter flags.
grid = {k: (v if isinstance(v, list) else [v]) for k, v in params.items()}
keys = list(grid)
for i, values in enumerate(product(*(grid[k] for k in keys))):
    # Single-letter names become short flags (-k), longer ones long flags
    # (--key). Values are shell-quoted because run_cmds uses shell=True.
    flags = ''.join(
        ' %s%s %s' % ('--' if len(k) > 1 else '-', k, shlex.quote(str(val)))
        for k, val in zip(keys, values))
    # ' -l' is appended to every command; its meaning is defined by the
    # target command (not visible here).
    c = cmd + flags + ' -l'
    if not opt['dist']:
        # Spread runs over the visible GPUs round-robin.
        c += ' -g %d' % gs[i % len(gs)]
    cmds.append(c)

if not opt['run']:
    for c in cmds:
        print(c)
else:
    run_cmds(cmds, opt['max_jobs'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.