Skip to content

Instantly share code, notes, and snippets.

@marcoleewow
Last active April 30, 2019 03:45
Show Gist options
  • Save marcoleewow/6ac845fb000e15bd0dc5bdfb9b2f3ad0 to your computer and use it in GitHub Desktop.
Save marcoleewow/6ac845fb000e15bd0dc5bdfb9b2f3ad0 to your computer and use it in GitHub Desktop.
Character error rate (CER) and word error rate (WER) calculation for handwriting recognition. Requires `pip install editdistance==0.4`.
import editdistance
def cer(y_true: str, y_pred: str) -> float:
    """Return the character error rate (CER) of *y_pred* against *y_true*.

    CER is the Levenshtein edit distance between the two strings divided by
    the number of characters in the reference. Trailing whitespace is
    stripped from both strings before comparison. The reference length used
    as the denominator is measured after that same stripping, so trailing
    whitespace neither counts as an error nor inflates the length.

    Args:
        y_true: Ground-truth (reference) text.
        y_pred: Predicted (hypothesis) text.

    Returns:
        Edit distance divided by the stripped reference length. This can
        exceed 1.0 when the prediction is much longer than the reference.
        Returns 0.0 when both stripped strings are empty.

    Raises:
        TypeError: If either argument is not a ``str``.
        ZeroDivisionError: If the stripped reference is empty but the
            stripped prediction is not (CER is undefined in that case).
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(y_true, str) or not isinstance(y_pred, str):
        raise TypeError(
            f"cer() expects two str arguments, got {type(y_true).__name__} "
            f"and {type(y_pred).__name__}"
        )
    reference = y_true.rstrip()
    hypothesis = y_pred.rstrip()
    if not reference:
        # Two empty strings are a perfect match; any output against an empty
        # reference has no meaningful rate.
        if not hypothesis:
            return 0.0
        raise ZeroDivisionError("CER is undefined for an empty reference string")
    n_err = editdistance.eval(reference, hypothesis)
    # Normalise by the *stripped* length so numerator and denominator agree.
    return n_err / len(reference)
def wer(transcript: str, reference: str) -> float:
    """Return the word error rate (WER) of *transcript* against *reference*.

    Both strings are split on runs of whitespace (``str.split()``). WER is
    the word-level Levenshtein edit distance between the two word lists
    divided by the number of words in *reference*.

    Note: the argument order is hypothesis first, reference second. This is
    the reverse of ``cer(y_true, y_pred)``. The denominator is always the
    word count of the second argument.
    NOTE(review): the ``__main__`` self-test calls ``wer(ground_truth,
    prediction)``. Results agree there only because both strings have the
    same word count. Confirm the intended order with callers.

    Args:
        transcript: Hypothesis (recognised) text.
        reference: Reference (ground-truth) text.

    Returns:
        Word edit distance divided by the reference word count. This can
        exceed 1.0. Returns 0.0 when both strings contain no words.

    Raises:
        TypeError: If either argument is not a ``str``.
        ZeroDivisionError: If *reference* contains no words but *transcript*
            does (WER is undefined in that case).
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(transcript, str) or not isinstance(reference, str):
        raise TypeError(
            f"wer() expects two str arguments, got {type(transcript).__name__} "
            f"and {type(reference).__name__}"
        )
    # Use new names rather than rebinding the str parameters to lists.
    hyp_words = transcript.split()
    ref_words = reference.split()
    if not ref_words:
        if not hyp_words:
            return 0.0
        raise ZeroDivisionError("WER is undefined for a reference with no words")
    # editdistance.eval accepts sequences of hashable items, so word lists
    # are compared token by token.
    n_err = editdistance.eval(hyp_words, ref_words)
    return n_err / len(ref_words)
if __name__ == "__main__":
    # Self-test. Each case pairs the ground truth with a prediction that
    # differs from it by at most one character. Labels name the character-level
    # edit that turns the ground truth into the prediction. At word level,
    # every non-exact case is a single word substitution (1 of 2 words).
    ground_truth = "hello world"
    cases = [
        # (label, prediction, expected CER, expected WER)
        ("exact match", "hello world", 0.0, 0.0),
        ("one insertion", "hello world!", 1.0 / 11.0, 1.0 / 2.0),
        ("one substitution", "hallo world", 1.0 / 11.0, 1.0 / 2.0),
        ("one deletion", "ello world", 1.0 / 11.0, 1.0 / 2.0),
    ]
    for label, prediction, expected_cer, expected_wer in cases:
        cer_score = cer(ground_truth, prediction)
        wer_score = wer(ground_truth, prediction)
        # Exact comparison is safe: both sides are the same int/int division.
        assert cer_score == expected_cer, (label, "cer", cer_score)
        assert wer_score == expected_wer, (label, "wer", wer_score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment