Skip to content

Instantly share code, notes, and snippets.

@swenson
Last active August 3, 2024 21:51
Show Gist options
  • Save swenson/feb7f8e4631e32437693ca1f02a0ad71 to your computer and use it in GitHub Desktop.
Save swenson/feb7f8e4631e32437693ca1f02a0ad71 to your computer and use it in GitHub Desktop.
Script to clean up SRT transcripts. Run it like `python3 clean.py input.srt output.srt`
from collections import namedtuple
from pprint import pprint
import math
import sys
# One subtitle cue: `start` and `end` are the timestamp strings taken from the
# timing line, `text` is the (possibly multi-line) caption body.
Caption = namedtuple('Caption', ['start', 'end', 'text'])


def parse(l):
    """Parse one SRT cue block into a Caption.

    The block is expected to hold the cue index on its first line, the
    ``start --> end`` timing on its second line and the caption text on the
    remaining lines. The index line is ignored.

    Args:
        l: text of a single cue block, lines separated by ``\\n``.

    Returns:
        Caption(start, end, text) with surrounding whitespace removed from
        all three fields.

    Raises:
        IndexError: if the block has fewer than two lines.
        ValueError: if the timing line does not contain exactly one ``-->``.
    """
    lines = l.split('\n')
    # Strip each timestamp: later stages compare `end == start` as strings,
    # so a stray trailing space or '\r' would silently defeat those checks.
    start, end = (stamp.strip() for stamp in lines[1].split('-->'))
    text = '\n'.join(lines[2:]).strip()
    return Caption(start, end, text)
def fix_puncuation(lines):
    """Re-attach stray punctuation and tiny fragments to the previous caption.

    A caption is only altered when it starts at exactly the timestamp where
    the previously kept caption ends. In that case:

    * if it begins with a single non-alphanumeric character followed by a
      space (e.g. ``", and then"``), that character is appended to the
      previous caption and removed from this one;
    * otherwise, if its text is at most three characters long, it is folded
      into the previous caption, whose end time becomes this caption's end
      time. A space is inserted before the fragment only when the fragment
      starts with a letter or digit.

    Captions with empty text, and the first caption, are passed through
    unchanged.

    Args:
        lines: iterable of Caption-style namedtuples with ``start``, ``end``
            and ``text`` fields.

    Returns:
        A new list of captions; the input is not modified.
    """
    new_lines = []
    for l in lines:
        # Nothing to attach to yet, or no characters to inspect. (The
        # original guard tested len() of the tuple itself, which is always 3,
        # so empty text went on to raise IndexError below.)
        if not new_lines or not l.text:
            new_lines.append(l)
            continue
        prev = new_lines[-1]
        if prev.end == l.start:
            # Leading lone punctuation mark: hand it to the previous caption.
            # The length check protects l.text[1]; isalnum() keeps a leading
            # digit ("5 apples") from being mistaken for punctuation.
            if len(l.text) >= 2 and not l.text[0].isalnum() and l.text[1] == ' ':
                new_lines[-1] = prev._replace(text=prev.text.strip() + l.text[0])
                new_lines.append(l._replace(text=l.text[2:].strip()))
                continue
            # Very short fragment: merge it and extend the previous timing.
            if len(l.text) <= 3:
                joiner = ' ' if l.text[0].isalnum() else ''
                new_lines[-1] = prev._replace(
                    end=l.end,
                    text=prev.text.strip() + joiner + l.text,
                )
                continue
        new_lines.append(l)
    return new_lines
def split42(t):
    """Break *t* into newline-separated pieces of at most 42 characters.

    Each break is made at the last space found within the first 42
    characters of the remaining text; when there is no such space the text
    is cut hard at 42 characters. Every piece is stripped of surrounding
    whitespace. Text of 42 characters or fewer is returned unchanged.
    """
    pieces = []
    remaining = t
    while len(remaining) > 42:
        cut = remaining.rfind(' ', 0, 42)
        if cut < 0:
            # No space to break on: fall back to a hard cut.
            cut = 42
        pieces.append(remaining[:cut].strip())
        remaining = remaining[cut:].strip()
    pieces.append(remaining)
    return '\n'.join(pieces)
def fix_length(lines):
    """Wrap caption text so that no line exceeds 42 characters.

    When no caption has text longer than 42 characters the input list is
    returned as is. Otherwise a new list is built in which every caption
    keeps its start and end but has its text passed through ``split42``.
    """
    needs_wrapping = False
    for caption in lines:
        if len(caption.text) > 42:
            needs_wrapping = True
            break
    if not needs_wrapping:
        return lines
    return [
        Caption(caption.start, caption.end, split42(caption.text))
        for caption in lines
    ]
def clean(lines):
    """Filter and merge captions, then wrap long lines and fix punctuation.

    Steps, in order, for each caption:

    * drop it if its text contains any of ``[ ] ( )``;
    * if its stripped text equals the previously kept caption's and it
      starts exactly where that caption ends, extend the kept caption's end
      instead of adding a duplicate;
    * if its stripped text is a single character, append that character to
      the previously kept caption and extend that caption's end;
    * otherwise keep it.

    The surviving captions are passed through ``fix_length`` and then
    ``fix_puncuation``, whose result is returned.
    """
    kept = []
    for cue in lines:
        if any(mark in cue.text for mark in '[]()'):
            continue
        body = cue.text.strip()
        if kept:
            last = kept[-1]
            same_text = last.text.strip() == body
            if same_text and last.end == cue.start:
                # Back-to-back repeat: stretch the existing caption.
                kept[-1] = Caption(last.start, cue.end, last.text)
                continue
            if len(body) == 1:
                # Lone character: glue it onto the previous caption.
                kept[-1] = Caption(last.start, cue.end, last.text.strip() + body)
                continue
        kept.append(cue)
    return fix_puncuation(fix_length(kept))
def make_time(x):
    """Convert an ``HH:MM:SS,mmm`` timestamp string to seconds.

    Surrounding whitespace is ignored. The part after the comma is read as
    a count of milliseconds.

    >>> make_time('01:01:06,123')
    3666.123
    """
    clock, fraction = x.strip().split(',')
    hours, minutes, secs = map(int, clock.split(':'))
    whole_seconds = hours * 3600 + minutes * 60 + secs
    return whole_seconds + int(fraction) / 1000
def time_str(x):
    """Format a number of seconds as an ``HH:MM:SS,mmm`` timestamp.

    The value is rounded to the nearest millisecond before being split into
    fields. Truncating the fractional part instead loses a millisecond
    whenever the float sits just below the intended value (1.001 is stored
    as 1.00099999...), which breaks round-tripping with ``make_time``.

    Args:
        x: non-negative time in seconds (int or float).

    Returns:
        The timestamp string, with hours, minutes and seconds zero-padded to
        two digits and milliseconds to three.

    >>> time_str(3666.123)
    '01:01:06,123'
    >>> time_str(1.001)
    '00:00:01,001'
    """
    total_millis = int(round(x * 1000))
    seconds, millis = divmod(total_millis, 1000)
    h, remainder = divmod(seconds, 3600)
    m, s = divmod(remainder, 60)
    return f'{h:02d}:{m:02d}:{s:02d},{millis:03d}'
if __name__ == '__main__':
    # Both the input path (argv[1]) and the output path (argv[2]) are used
    # below, so require both up front instead of failing later on argv[2].
    if len(sys.argv) < 3:
        print('Usage: python3 clean.py input output')
        sys.exit(1)
    print(sys.argv[1])

    # Read as bytes and decode leniently so undecodable bytes are dropped
    # rather than aborting the run.
    with open(sys.argv[1], 'rb') as fin:
        txt = fin.read().decode('utf-8', errors='ignore')
    # Normalise CRLF / CR line endings; otherwise a file with Windows line
    # endings contains no '\n\n' and would be parsed as one giant cue.
    txt = txt.replace('\r\n', '\n').replace('\r', '\n')
    lines = [parse(l.strip()) for l in txt.split('\n\n') if len(l.strip()) > 0]
    lines = clean(lines)

    # Report every pair of captions where an earlier one ends after a later
    # one starts.
    times = [
        (i, make_time(line.start), make_time(line.end))
        for i, line in enumerate(lines)
    ]
    for i, start, _end in times[1:]:
        for j in range(i):
            if times[j][2] > start:
                print(f'overlap: {j} x {i}: {lines[j].text} {lines[i].text}')

    # Write the renumbered cues. The input was decoded as UTF-8, so encode
    # the output the same way instead of relying on the platform default.
    check = False
    with open(sys.argv[2], 'w', encoding='utf-8') as fout:
        for i, line in enumerate(lines):
            fout.write(f'{i+1}\n{line.start} --> {line.end}\n{line.text}\n\n')
            # Flag only the first pair of consecutive captions with
            # identical text, as a hint for manual review.
            if i > 0 and not check:
                if lines[i-1].text.strip() == line.text.strip():
                    print(f'check: {i}: {line.text.strip()}')
                    check = True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment