@suryavanshi
Last active March 13, 2022 23:25
from transformers import pipeline
import random

# Fill-mask pipeline backed by the cased BERT base model
unmasker = pipeline('fill-mask', model='bert-base-cased')

input_text = "I went to see a movie in the theater"
orig_text_list = input_text.split()
len_input = len(orig_text_list)

# Random index of the word to replace (randint is inclusive on both ends,
# so index 0, the first word, is never masked)
rand_idx = random.randint(1, len_input - 1)
orig_word = orig_text_list[rand_idx]

new_text_list = orig_text_list.copy()
new_text_list[rand_idx] = '[MASK]'
new_mask_sent = ' '.join(new_text_list)
print("Masked sentence->", new_mask_sent)
# e.g. I went to [MASK] a movie in the theater

augmented_text_list = unmasker(new_mask_sent)

# Take the highest-scoring prediction whose word differs from the original,
# so the new word and the old word are not the same
augmented_text = input_text  # fallback if every prediction repeats the original word
for res in augmented_text_list:
    if res['token_str'] != orig_word:
        augmented_text = res['sequence']
        break
print("Augmented text->", augmented_text)
# e.g. I went to watch a movie in the theater
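The filtering loop above can be pulled into a small function and checked without downloading a model. This is a minimal sketch that assumes only what the gist already relies on: the fill-mask pipeline returns a list of dicts, best score first, each with 'token_str' and 'sequence' keys. The function name pick_replacement, the case-insensitive comparison, the skipping of '##' WordPiece pieces, and fake_unmasker are illustrative choices, not part of the original code.

```python
from typing import Callable, Dict, List

Prediction = Dict[str, object]

def pick_replacement(masked_sentence: str,
                     orig_word: str,
                     unmasker: Callable[[str], List[Prediction]],
                     fallback: str) -> str:
    """Hypothetical helper: return the best-scoring filled sentence whose
    predicted word differs from orig_word, or fallback if none qualifies."""
    for res in unmasker(masked_sentence):
        word = str(res['token_str']).strip()
        # Assumption: BERT may predict WordPiece continuation pieces like '##ing'; skip them
        if word.startswith('##'):
            continue
        # Case-insensitive check so 'See' does not count as a new word for 'see'
        if word.lower() != orig_word.lower():
            return str(res['sequence'])
    return fallback

# Hypothetical stand-in for pipeline('fill-mask', ...) so the logic runs offline;
# these predictions are made up for illustration, not real BERT output.
def fake_unmasker(sentence: str) -> List[Prediction]:
    words = ['see', '##ing', 'watch']
    return [{'token_str': w, 'sequence': sentence.replace('[MASK]', w)} for w in words]

masked = "I went to [MASK] a movie in the theater"
print(pick_replacement(masked, 'see', fake_unmasker,
                       fallback="I went to see a movie in the theater"))
```

With the real model, pass the gist's unmasker in place of fake_unmasker. Returning a fallback means the caller always gets a sentence back, even when every prediction repeats the original word.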
@KallisteCrts

Hey there! May I ask where orig_word is coming from? Thank you!!

@suryavanshi
Author

Hi, thanks for pointing it out. While copying from my notebook to the gist, I forgot to include that line. 'orig_word' is the original word at the random index, so it's: orig_word = orig_text_list[rand_idx]
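One way to avoid the missing-line problem discussed in this thread is to compute the masked sentence and the original word in the same place, from the same index. A minimal sketch, assuming whitespace tokenization as in the gist; the name mask_random_word and the seeded random.Random argument are hypothetical additions for reproducibility.

```python
import random
from typing import List, Tuple

def mask_random_word(text: str, rng: random.Random,
                     mask_token: str = '[MASK]') -> Tuple[str, str, int]:
    """Mask one whitespace-separated word (never the first, matching the gist's
    randint(1, n - 1)) and return (masked_sentence, orig_word, index)."""
    words: List[str] = text.split()
    if len(words) < 2:
        raise ValueError("need at least two words to mask a non-initial one")
    idx = rng.randint(1, len(words) - 1)
    orig_word = words[idx]  # read before overwriting, so it is never lost
    masked = words.copy()
    masked[idx] = mask_token
    return ' '.join(masked), orig_word, idx

rng = random.Random(0)
masked_sentence, orig_word, idx = mask_random_word("I went to see a movie in the theater", rng)
print(masked_sentence, '|', orig_word)
```

Returning the word together with the masked sentence means the later comparison against the model's predictions always has the value it needs.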
