Skip to content

Instantly share code, notes, and snippets.

@tarekziade
Created August 5, 2024 09:05
Show Gist options
  • Save tarekziade/42ba7d0952e9ed8326ba7bb526a92c1e to your computer and use it in GitHub Desktop.
Save tarekziade/42ba7d0952e9ed8326ba7bb526a92c1e to your computer and use it in GitHub Desktop.
from transformers import GPT2Tokenizer, AutoModelForVision2Seq
import requests
# Model id passed to `from_pretrained` for both the tokenizer and the
# vision-to-text model loaded further down this script.
model_name: str = "mozilla/distilvit"
def load_words_from_url(url, timeout=30):
    """Download a newline-separated word list and return its entries as a set.

    Every line of the response body is stripped of surrounding whitespace.
    Blank lines are discarded, so the result never contains the empty string.

    Args:
        url: Address of a plain-text file with one entry per line.
        timeout: Seconds to wait for the server; forwarded to
            ``requests.get`` so a stalled download cannot hang forever.

    Returns:
        A set of the distinct, non-empty, stripped lines. A set has no
        stable iteration order, so sort it if you need reproducible output.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    words = {line.strip() for line in response.text.splitlines()}
    # A blank line strips to "", which is not a word. Left in, it would later
    # be tokenized (with the prefix space) and banned as a bare-space token.
    words.discard("")
    return words
# Source of the word list: a plain-text file from the better_profanity
# project, read one entry per line by load_words_from_url.
bad_words_url = (
    "https://raw.githubusercontent.com/snguyenthanh/better_profanity/"
    "master/better_profanity/profanity_wordlist.txt"
)
# Words whose token sequences will be banned from generation
# (downloaded over the network when this script runs).
bad_words = load_words_from_url(bad_words_url)
# Tokenizer of the same model, created with add_prefix_space=True so each word
# is presumably encoded as it would appear after a preceding word (leading
# space) rather than as the very first token of the text.
# NOTE(review): other surface forms (e.g. capitalized words, or a word at the
# very start of the output) likely tokenize differently and are then not
# covered by the ban list built from this tokenizer; confirm if that matters.
tokenizer_with_prefix_space = GPT2Tokenizer.from_pretrained(model_name, add_prefix_space=True)
def get_tokens_as_list(word_list, tokenizer=None):
    """Tokenize words into the token-id sequences used for ``bad_words_ids``.

    Args:
        word_list: Iterable of words, e.g. the set returned by
            ``load_words_from_url``. The result follows its iteration order.
        tokenizer: Tokenizer to use. Defaults to the module-level
            ``tokenizer_with_prefix_space``. It is called with a list of
            strings and ``add_special_tokens=False``, and its result must
            expose ``input_ids`` as one id list per input string.

    Returns:
        A list of token-id lists, one per non-blank word, encoded without
        special tokens. Empty id sequences are omitted.
    """
    if tokenizer is None:
        tokenizer = tokenizer_with_prefix_space
    # Blank or whitespace-only entries name no word to ban; skip them.
    words = [word for word in word_list if word.strip()]
    if not words:
        return []
    # One batched call instead of one call per word. No padding is requested,
    # so each word should encode exactly as in a single-item call.
    encoded = tokenizer(words, add_special_tokens=False).input_ids
    # An empty id sequence identifies no token to ban; drop it defensively.
    return [ids for ids in encoded if ids]
# Tokenize the ban list in sorted order. Iterating a set of strings follows
# hash order, which is randomized per interpreter run, so without sorting the
# saved generation_config.json would differ from run to run.
bad_word_ids = get_tokens_as_list(sorted(bad_words))
# Load the model, attach the banned token sequences to its generation config,
# and write that config to ./generation_config.json (current working dir).
# The model object itself is used here only for its generation config.
# NOTE(review): this loads the full model just to read its generation config;
# `GenerationConfig.from_pretrained(model_name)` may be a lighter alternative
# if the repo ships a generation_config.json. Confirm before switching.
model = AutoModelForVision2Seq.from_pretrained(model_name)
generation_config = model.generation_config
generation_config.update(bad_words_ids=bad_word_ids)
generation_config.to_json_file("generation_config.json")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment