@wbuchwalter
Last active April 2, 2020 14:37
import time
import argparse
import os
import csv
import json
from urllib.request import urlopen
from multiprocessing import Pool
parser = argparse.ArgumentParser()
parser.add_argument('input', help='Path to GCC TSV input file')
parser.add_argument('output', help='Output directory')
parser.add_argument('-t', help='Number of worker processes, default: 20', default=20, type=int)
parser.add_argument('-f', help='Checkpoint frequency (images per batch), default: 1000', default=1000, type=int)
args = parser.parse_args()
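# Example invocation (a sketch: the script name and TSV filename are
# assumptions; the Conceptual Captions splits ship as caption<TAB>url TSVs):
#   python download_gcc.py Train_GCC-training.tsv ./gcc_train -t 20 -f 1000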
# A trailing slash would break the os.path.basename calls below
assert args.output[-1] != '/', "Output path must not end with a trailing slash"

REQUEST_TIMEOUT = 0.5  # seconds
# Checkpoint file format (COCO-captions style):
# {
#     "images": [{"id": 12, "filename": "00012.jpg"}],
#     "annotations": [{"id": 12, "image_id": 12, "caption": "some piece of text"}]
# }
anns = {
    'annotations': [],
    'images': []
}
cursor_start = 0
cursor_pos = 0
dest_ann_file = os.path.join(args.output, 'captions_{}.json'.format(os.path.basename(args.output)))
dest_img_dir = os.path.join(args.output, os.path.basename(args.output))
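# With output "./gcc_train", for example, annotations are checkpointed to
# ./gcc_train/captions_gcc_train.json and images saved under ./gcc_train/gcc_train/.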
# If a checkpoint file already exists, read it to find the last processed
# index and resume from there; rows before cursor_start are fast-forwarded.
if os.path.isfile(dest_ann_file):
    anns = json.load(open(dest_ann_file, 'r'))
    cursor_start = anns['images'][-1]['id']
    print('Output destination already exists, resuming download from image # %i...' % cursor_start)
else:
    os.makedirs(dest_img_dir, exist_ok=True)

def fetch_image_data(url):
    response = urlopen(url, timeout=REQUEST_TIMEOUT)
    if response.status != 200:
        raise Exception('Bad status code')
    img_data = response.read()
    header = img_data[:11]
    # Some hosts answer a broken image with a redirect or an HTML placeholder
    # but still return a 200, so also check the JPEG magic bytes. (A stricter
    # check would additionally verify header[6:11] == b'JFIF\0'.)
    if header[:3] != b'\xff\xd8\xff':
        raise Exception('Corrupted image')
    return img_data

def process_image(tup):
    img_id, caption, url = tup
    try:
        img_data = fetch_image_data(url)
    except Exception:
        return None
    img_filename = "{0:07d}.jpg".format(img_id)
    with open(os.path.join(dest_img_dir, img_filename), 'w+b') as f:
        f.write(img_data)
    ann = {"id": img_id, "image_id": img_id, "caption": caption}
    img = {"id": img_id, "filename": img_filename}
    return (ann, img)

with open(args.input, 'r') as tsvin:
    reader = csv.reader(tsvin, delimiter='\t')
    nb_failed = 0
    buffer = []
    for cursor_pos, (caption, url) in enumerate(reader):
        if cursor_pos < cursor_start:
            # Fast-forward to cursor_start when resuming a download
            continue
        buffer.append((cursor_pos, caption, url))

processing_cursor = 0
with Pool(args.t) as p:
    while processing_cursor < len(buffer):
        t0 = time.time()
        batch = buffer[processing_cursor : processing_cursor + args.f]
        res = p.map(process_image, batch)
        # Discard failed downloads (404s, timeouts, invalid headers, etc.)
        valid_data = [r for r in res if r is not None]
        nb_failed_batch = len(res) - len(valid_data)
        if valid_data:
            annotations, images = zip(*valid_data)
            anns['annotations'].extend(annotations)
            anns['images'].extend(images)
        # Checkpoint the annotation file after every batch so a crash can resume
        json.dump(anns, open(dest_ann_file, 'w+'))
        processing_cursor += len(batch)
        print("Time for a batch:", time.time() - t0)
        print("%i images failed to download in this batch" % nb_failed_batch)
        nb_failed += nb_failed_batch

print("%i images failed to download in total" % nb_failed)
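
# Minimal sketch of reading the resulting checkpoint file back (the output
# name "gcc_train" is only an example, matching the invocation above):
#
#     import json
#     anns = json.load(open('gcc_train/captions_gcc_train.json'))
#     id_to_file = {img['id']: img['filename'] for img in anns['images']}
#     for ann in anns['annotations'][:3]:
#         print(id_to_file[ann['image_id']], ann['caption'])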