Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Last active July 14, 2024 19:58
Show Gist options
  • Save kastnerkyle/7b6be9cc2d77f1301b75fd2d8c1c894f to your computer and use it in GitHub Desktop.
Save kastnerkyle/7b6be9cc2d77f1301b75fd2d8c1c894f to your computer and use it in GitHub Desktop.
ROVER system combination algorithm
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import unicode_literals
# author: Kyle Kastner
# References:
# needleman wunsch (could use other alignment algorithms instead)
# https://colab.research.google.com/github/zaneveld/full_spectrum_bioinformatics/blob/master/content/08_phylogenetic_trees/needleman_wunsch_alignment.ipynb
# ROVER discussion
# Data-Diverse Redundant Processing for Noise-Robust Automatic Speech Recognition
# Mustafa K. Hotaki
# https://libraetd.lib.virginia.edu/downloads/n583xv77p?filename=2_Hotaki_Mustafa_2020_MS.pdf
# Related ROVER code
# https://github.com/Toloka/CrowdSpeech/blob/main/rover.py
import numpy as np
from pprint import pprint
import string
from collections import OrderedDict, defaultdict
def pprint_arr(arr, seq_1=None, seq_2=None):
    """Print a 2D array (e.g. a scoring or traceback matrix) with optional labels.

    arr   -- 2D array with a ``shape`` attribute; one printed line per row.
    seq_1 -- optional row labels. Row 0 is labelled "^" (the empty-prefix row)
             and row i > 0 is labelled ``seq_1[i - 1]``.
    seq_2 -- optional column labels, printed as a header line that starts
             with " ^" (marking the empty-prefix column).

    Every collected line is printed, including the last matrix row when the
    header line is present.
    """
    lines = []
    if seq_2 is not None:
        # use ^ as the start of seq filler mark
        lines.append([" ^"] + [el for el in seq_2])
    for r_i in range(arr.shape[0]):
        row = [el for el in arr[r_i]]
        if seq_1 is not None:
            label = "^" if r_i == 0 else seq_1[r_i - 1]
            row = [label] + row
        lines.append(row)
    # Print everything collected. Looping over arr.shape[0] here would drop
    # the final matrix row whenever the seq_2 header line was prepended.
    for row in lines:
        print(*row)
def needleman_wuntch_align(seq_1, seq_2, gap_penalty=-1, mismatch_penalty=-1, match_bonus=1):
    """Fill the Needleman-Wunsch scoring and traceback matrices for two sequences.

    seq_1 -- sequence labelling the rows (a string or a list of tokens; its
             elements are compared with ``==``).
    seq_2 -- sequence labelling the columns.
    gap_penalty      -- score added for an insertion or deletion (default -1).
    mismatch_penalty -- score added for a substitution (default -1).
    match_bonus      -- score added when the two elements are equal (default 1).

    Returns ``(scoring_array, traceback_array)``, both of shape
    ``(len(seq_1) + 1, len(seq_2) + 1)``. ``scoring_array[r, c]`` is the best
    score for aligning ``seq_1[:r]`` with ``seq_2[:c]``. The matching traceback
    cell holds the move that produced it:
      - left arrow (U+2190) for a gap in seq_1;
      - up arrow (U+2191) for a gap in seq_2;
      - up-left arrow (U+2196) for a match or mismatch;
      - "-" at the origin cell [0, 0].
    On score ties the earlier candidate wins, in the order left, up, diagonal.
    The scoring array has an integer dtype unless one of the scores is a float.
    """
    up_arrow = "\u2191"
    left_arrow = "\u2190"
    up_left_arrow = "\u2196"
    n_rows = len(seq_1) + 1  # extra row for the empty prefix of seq_1
    n_cols = len(seq_2) + 1  # extra column for the empty prefix of seq_2
    # Keep an integer matrix for integer scores, but never truncate fractional
    # scores by storing them in an integer array.
    use_float = any(isinstance(p, (float, np.floating))
                    for p in (gap_penalty, mismatch_penalty, match_bonus))
    scoring_array = np.full([n_rows, n_cols], 0.0 if use_float else 0)
    # "-" stays in the origin cell as the traceback stop mark; every other
    # cell is overwritten below.
    traceback_array = np.full([n_rows, n_cols], "-")
    # First row: the seq_1 prefix is empty, so only moves from the left reach it.
    for c in range(1, n_cols):
        scoring_array[0, c] = scoring_array[0, c - 1] + gap_penalty
        traceback_array[0, c] = left_arrow
    # First column: the seq_2 prefix is empty, so only moves from above reach it.
    for r in range(1, n_rows):
        scoring_array[r, 0] = scoring_array[r - 1, 0] + gap_penalty
        traceback_array[r, 0] = up_arrow
    for r in range(1, n_rows):
        for c in range(1, n_cols):
            step = match_bonus if seq_1[r - 1] == seq_2[c - 1] else mismatch_penalty
            candidates = (
                (scoring_array[r, c - 1] + gap_penalty, left_arrow),
                (scoring_array[r - 1, c] + gap_penalty, up_arrow),
                (scoring_array[r - 1, c - 1] + step, up_left_arrow),
            )
            best_score, best_arrow = candidates[0]
            for cand_score, cand_arrow in candidates[1:]:
                # strict ">" keeps the earliest candidate on ties (left, then up, then diagonal)
                if cand_score > best_score:
                    best_score, best_arrow = cand_score, cand_arrow
            scoring_array[r, c] = best_score
            traceback_array[r, c] = best_arrow
    return scoring_array, traceback_array
def traceback_alignment(traceback_array, seq1, seq2,
                        up_arrow="\u2191", left_arrow="\u2190", up_left_arrow="\u2196", stop="-", debug_print=False):
    """Align seq1 and seq2 using the traceback matrix and return as two strings

    traceback_array -- a numpy array with arrow characters indicating the direction from
        which the best path to a given alignment position originated; the walk starts at
        [len(seq1), len(seq2)] (e.g. the traceback matrix from needleman_wuntch_align)
    seq1 - a sequence represented as a string (or a sequence of strings)
    seq2 - a sequence represented as a string (or a sequence of strings)
    up_arrow - the unicode used for the up arrows (there are several arrow symbols in Unicode)
    left_arrow - the unicode used for the left arrows
    up_left_arrow - the unicode used for the diagonal arrows
    stop - the symbol used in the upper left to indicate the end of the alignment
    debug_print - if True, print every traceback step and the final alignment

    Returns (aligned_seq1, aligned_seq2): strings built by concatenating the
    aligned elements, with "-" wherever the other sequence has a gap. If the
    elements are multi-character strings, the result cannot be split back
    into elements unambiguously.

    Raises ValueError when an entry on the traceback path is none of the
    given arrows nor the stop symbol.

    from:
    https://colab.research.google.com/github/zaneveld/full_spectrum_bioinformatics/blob/master/content/08_phylogenetic_trees/needleman_wunsch_alignment.ipynb
    """
    row = len(seq1)
    col = len(seq2)
    # Pieces are gathered from the end of the alignment towards its start and
    # reversed once at the end; prepending to a string on every step is quadratic.
    pieces_1 = []
    pieces_2 = []
    indicator = []
    while True:
        if debug_print:
            print("Currently on row:", row)
            print("Currently on col:", col)
        arrow = traceback_array[row, col]
        if debug_print:
            print("Arrow:", arrow)
        # Compare with == rather than "is": entries read out of a numpy array
        # are fresh string objects and are never identical to `stop`.
        if arrow == up_arrow:
            if debug_print:
                print("insert indel into top sequence")
            # seq1 element paired with a gap in seq2
            pieces_1.append(seq1[row - 1])
            pieces_2.append("-")
            indicator.append(" ")
            row -= 1
        elif arrow == up_left_arrow:
            if debug_print:
                print("match or mismatch")
            # row - 1 / col - 1 because row/col 0 of the matrix is the empty prefix
            seq1_character = seq1[row - 1]
            seq2_character = seq2[col - 1]
            pieces_1.append(seq1_character)
            pieces_2.append(seq2_character)
            indicator.append("|" if seq1_character == seq2_character else " ")
            row -= 1
            col -= 1
        elif arrow == left_arrow:
            if debug_print:
                print("Insert indel into left sequence")
            # gap in seq1 paired with a seq2 element
            pieces_1.append("-")
            pieces_2.append(seq2[col - 1])
            indicator.append(" ")
            col -= 1
        elif arrow == stop:
            if debug_print:
                print("Finished!")
            break
        else:
            raise ValueError("Traceback array entry at {},{}: {} is not recognized as an up arrow ({}),left_arrow ({}), up_left_arrow ({}), or a stop ({}).".format(row, col, arrow, up_arrow, left_arrow, up_left_arrow, stop))
    aligned_seq1 = "".join(reversed(pieces_1))
    aligned_seq2 = "".join(reversed(pieces_2))
    if debug_print:
        print(aligned_seq1)
        print("".join(reversed(indicator)))
        print(aligned_seq2)
    return aligned_seq1, aligned_seq2
class OrderedLambdaDefaultDict(OrderedDict):
    """OrderedDict that inserts and returns ``([], 0)`` for a missing key.

    Each missing key gets its own freshly built ``([], 0)`` tuple, so the
    empty lists are never shared between keys.
    """

    # staticmethod is required: a bare lambda stored on the class becomes a
    # bound method, so ``self.factory()`` would pass ``self`` to a
    # zero-argument lambda and raise TypeError.
    factory = staticmethod(lambda: ([], 0))

    def __missing__(self, key):
        self[key] = value = self.factory()
        return value
class WTN:
    """Word transition network: a weighted directed graph over hashable nodes.

    Adding the same transition again adds to its weight. Nodes keep their
    insertion order, which makes the path search deterministic.
    """

    def __init__(self):
        # from_node -> OrderedDict(to_node -> accumulated weight)
        self.transitions = OrderedDict()
        # every node seen as either end of a transition
        self.words_positions = set()

    def _add_word(self, word_position):
        """Register a node; registering it again is a no-op."""
        self.words_positions.add(word_position)
        if word_position not in self.transitions:
            self.transitions[word_position] = OrderedDict()

    def add_transition(self, from_word_position, to_word_position, weight=1.0):
        """Add ``weight`` to the edge from_word_position -> to_word_position, creating it if needed."""
        self._add_word(from_word_position)
        self._add_word(to_word_position)
        successors = self.transitions[from_word_position]
        successors[to_word_position] = successors.get(to_word_position, 0) + weight

    def _topological_order(self):
        """Return all nodes in topological order (Kahn's algorithm).

        The order depends only on insertion order, so it is deterministic.
        If the graph has a cycle, the nodes that cannot be ordered are
        appended in insertion order.
        """
        indegree = OrderedDict((node, 0) for node in self.transitions)
        for successors in self.transitions.values():
            for node in successors:
                indegree[node] += 1
        order = [node for node, degree in indegree.items() if degree == 0]
        i = 0
        while i < len(order):
            for successor in self.transitions[order[i]]:
                indegree[successor] -= 1
                if indegree[successor] == 0:
                    order.append(successor)
            i += 1
        if len(order) < len(indegree):
            placed = set(order)
            order.extend(node for node in self.transitions if node not in placed)
        return order

    def get_best_path_and_score(self, rover_alpha, confidence_fn):
        """Return the highest-scoring path through the network and its score.

        Every node starts as a zero-score path of its own. Following an edge
        ``u -> v`` adds the following to the best score ending at ``u``:

            rover_alpha * weight(u, v) / outdegree(u)
            + (1 - rover_alpha) * confidence_fn(path_to_u)

        Here ``path_to_u`` is the node list preceding ``u`` on its best path.
        A node's best path is replaced only on a strict improvement.

        Nodes are relaxed in topological order, so (for an acyclic network)
        a node's best score is final before its outgoing edges are used.
        Relaxing in plain insertion order loses improvements that reach an
        already-processed node through a later-inserted node.

        rover_alpha   -- mix between the edge-weight term (1.0) and the
                         confidence term (0.0).
        confidence_fn -- callable taking the current partial path (a list of
                         nodes) and returning a number.

        Returns ``(best_path, path_score)``. ``best_path`` is the node list
        ending at the highest-scoring node; on ties, the node earliest in
        topological order wins. An empty network gives ``([], 0)``.
        """
        order = self._topological_order()
        if not order:
            return [], 0
        best_paths = OrderedDict((node, ([], 0)) for node in order)
        for from_word_position in order:
            successors = self.transitions[from_word_position]
            if not successors:
                continue
            out_degree = float(len(successors))
            for to_word_position, weight in successors.items():
                current_path, current_score = best_paths[from_word_position]
                new_score = current_score + (rover_alpha * (weight / out_degree) + (1.0 - rover_alpha) * confidence_fn(current_path))
                if new_score > best_paths[to_word_position][1]:
                    best_paths[to_word_position] = (current_path + [from_word_position], new_score)
        # Find the best path by the maximum score
        final_word = max(best_paths, key=lambda word: best_paths[word][1])
        best_path, path_score = best_paths[final_word]
        return best_path + [final_word], path_score
# Fixed-seed RNG so the fake confidences are reproducible from run to run;
# every call to fake_confidences advances its state.
random_state = np.random.RandomState(2145)


def fake_confidences(current_preds):
    """Return a placeholder confidence drawn uniformly from [0, 1).

    current_preds -- the partial path being extended. It is accepted to match
                     the ``confidence_fn`` call signature but is ignored: the
                     value is random, not derived from the path.
    """
    # totally fake confidences
    return random_state.rand()
if __name__ == "__main__":
# Example from
# https://libraetd.lib.virginia.edu/downloads/n583xv77p?filename=2_Hotaki_Mustafa_2020_MS.pdf
word_based = True
base_seq_1 = "the cat in the hat sat on the mat"
rover_alpha = 1.0
# rover_alpha = 1.0
# [(u'-', 0), (u'the', 1), (u'cat', 2), (u'in', 3), (u'the', 4), (u'hat', 5), (u'sat', 6), (u'on', 7), (u'the', 8), (u'mat', 9)]
# 22.1666666667
# rover_alpha = 0.0
# [(u'-', 0), (u'the', 1), (u'cat', 2), (u'in', 3), (u'the', 4), (u'hat', 5), (u'sat', 6), (u'on', 7), (u'-', 8), (u'mat', 9)]
# 3.94283457846
all_seq_2 = ["the cat and the hat on mat",
"the bat in that hat sat in the mat",
"the cat end hat on the mat",
"the cat end hat on the mat",
"the cat in the at sat on the mat",
"the cat in at sat on mat"]
core_wtn = WTN()
_s = 0
for base_seq_2 in all_seq_2:
if word_based:
# make these into words, instead of doing char based align
seq_1 = base_seq_1.split(" ")
seq_2 = base_seq_2.split(" ")
seq_1_o = seq_1
seq_2_o = seq_2
vocab = OrderedDict()
# set up 100k vocab, won't use most
indexer = [str(el) for el in range(100000)]
_i = 0
for seq in [seq_1, seq_2]:
for w in seq:
if w not in vocab:
vocab[w] = (1, indexer[_i])
_i += 1
else:
vocab[w] = (vocab[w][0], vocab[w][1])
rev_vocab = {v[1]: (v[0], k) for k, v in vocab.items()}
seq_1 = [vocab[el][1] for el in seq_1]
seq_2 = [vocab[el][1] for el in seq_2]
# get alignment and traceback for dynamic programming path
scoring_array, traceback_array = needleman_wuntch_align(seq_1, seq_2)
align_1, align_2 = traceback_alignment(traceback_array, seq_1, seq_2, debug_print=False)
if word_based:
align_1 = " ".join([rev_vocab[el][1] if el in rev_vocab else el for el in align_1])
align_2 = " ".join([rev_vocab[el][1] if el in rev_vocab else el for el in align_2])
# prepend "-"
words_align_1 = ["-"] + align_1.split(" ")
words_align_2 = ["-"] + align_2.split(" ")
words_align_1 = list(zip(words_align_1, range(len(words_align_1))))
words_align_2 = list(zip(words_align_2, range(len(words_align_2))))
else:
raise ValueError("Fix char based")
if _s == 0:
for fwa, twa in zip(words_align_1[:-1], words_align_1[1:]):
core_wtn.add_transition(fwa, twa)
for fwa, twa in zip(words_align_2[:-1], words_align_2[1:]):
core_wtn.add_transition(fwa, twa)
_s += 1
# now that we have all aligned sequences, build word transition network and do ROVER scoring
best_path, best_path_score = core_wtn.get_best_path_and_score(rover_alpha, fake_confidences)
print(best_path)
print(best_path_score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment