Skip to content

Instantly share code, notes, and snippets.

@felipemello1
Last active September 13, 2024 14:40
Show Gist options
  • Save felipemello1/5f2002433c6da3a21f33d6cdf82e702a to your computer and use it in GitHub Desktop.
Save felipemello1/5f2002433c6da3a21f33d6cdf82e702a to your computer and use it in GitHub Desktop.
script update configs
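"""
Script to update torchtune configs in bulk.

Walks every .yaml file under the configs directory (the file name itself is
not filtered; files without the pattern are simply left unchanged). For each
line that contains "lora_alpha: 16", it keeps that line and inserts a copy of
it directly below, with "lora_alpha: 16" replaced by "lora_dropout: 0.0", so
the new key keeps the same indentation and trailing newline.
Saves the modified files, then prints the files that were updated followed by
the files that were not.
"""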
import os
import shutil
def modify_yaml_file(file_path):
    """Insert ``lora_dropout: 0.0`` below every ``lora_alpha: 16`` line.

    Each matching line is kept. A copy of it, with ``lora_alpha: 16`` replaced
    by ``lora_dropout: 0.0``, is inserted directly after it, so the new key
    inherits the original line's indentation and line ending.

    The file is rewritten only when at least one line is inserted. A match
    whose following line is already the generated ``lora_dropout`` line is
    skipped, so running the script more than once does not duplicate the key.

    Args:
        file_path: Path of the YAML file to edit in place.

    Returns:
        True if the file was modified, False otherwise.
    """
    target = 'lora_alpha: 16'
    replacement = 'lora_dropout: 0.0'

    # newline='' disables newline translation on both read and write, so each
    # line keeps its original ending ('\n' or '\r\n') instead of being
    # normalized to the platform default.
    with open(file_path, 'r', encoding='utf-8', newline='') as file:
        lines = file.readlines()

    new_lines = []
    updated = False
    for index, line in enumerate(lines):
        new_lines.append(line)
        if target not in line:
            continue
        new_line = line.replace(target, replacement)
        next_line = lines[index + 1] if index + 1 < len(lines) else ''
        # Already patched by a previous run: don't insert a duplicate key.
        if next_line.rstrip('\r\n') == new_line.rstrip('\r\n'):
            continue
        if not line.endswith('\n'):
            # The match is the last line and has no trailing newline. Terminate
            # it, otherwise the inserted text would land on the same line.
            new_lines[-1] = line + '\n'
        new_lines.append(new_line)
        updated = True

    # Leave untouched files alone (no rewrite, no mtime change).
    if updated:
        with open(file_path, 'w', encoding='utf-8', newline='') as file:
            file.writelines(new_lines)
    return updated
def search_yaml_files(directory, *, name_contains=None):
    """Run ``modify_yaml_file`` on every ``.yaml`` file under *directory*.

    Walks *directory* recursively and prints the paths that were updated,
    followed by the paths that were not.

    Args:
        directory: Root directory to walk.
        name_contains: Optional substring. When given, only ``.yaml`` files
            whose file name contains it are processed (e.g. ``"lora"``).
            ``None`` (the default) processes every ``.yaml`` file.

    Returns:
        A ``(updated_files, not_updated_files)`` tuple of path lists, in the
        same order they are printed.
    """
    updated_files = []
    not_updated_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place makes os.walk descend in a stable order;
        # together with sorting ``files`` this makes the report reproducible.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith('.yaml'):
                continue
            if name_contains is not None and name_contains not in file_name:
                continue
            file_path = os.path.join(root, file_name)
            if modify_yaml_file(file_path):
                updated_files.append(file_path)
            else:
                not_updated_files.append(file_path)
    print("Updated files:")
    for file_path in updated_files:
        print(file_path)
    print("\nFiles not updated (no 'lora_alpha: 16' found):")
    for file_path in not_updated_files:
        print(file_path)
    return updated_files, not_updated_files
# Relative path: the script is meant to be run from the torchtune repo root.
directory = 'recipes/configs'

# Guarded so that importing this module does not rewrite files as a side effect.
if __name__ == "__main__":
    # os.walk silently yields nothing for a missing path, which would look like
    # "no files found" rather than "wrong working directory", so fail loudly.
    if not os.path.isdir(directory):
        raise SystemExit(
            f"Config directory not found: {directory!r} "
            "(run this script from the torchtune repo root)"
        )
    search_yaml_files(directory)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment