@f0ster
Created April 28, 2024 04:28
Shard Large LLM models
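This script re-shards Llama-style consolidated checkpoints into a different number of tensor-parallel shards: it loads every consolidated.*.pth file from the input directory into CPU memory, concatenates each weight across the input shards along its parallel dimension, splits the merged tensor into the requested number of slices, and writes one consolidated.NN.pth per output rank together with a copy of params.json.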
import os
import json
import sys
import torch
import glob

def load_parameters(directory):
    """Load model parameters from a JSON file."""
    with open(os.path.join(directory, 'params.json'), 'r') as file:
        return json.load(file)


def verify_shard_compatibility(parameters, num_shards):
    """Ensure the parameter dimension is divisible by the number of shards."""
    if parameters['dim'] % num_shards != 0:
        raise ValueError(f"Number of shards must divide parameter dimension {parameters['dim']}")

def load_checkpoints(directory):
    """Load all model checkpoints into CPU memory, in shard order."""
    pattern = os.path.join(directory, '*.pth')
    # Sort the paths so consolidated.00.pth, consolidated.01.pth, ... are merged in rank order.
    return [torch.load(path, map_location='cpu') for path in sorted(glob.glob(pattern))]


def shard_tensor(tensor, num_shards, dim):
    """Shard a tensor into equal slices along the specified dimension."""
    slice_size = tensor.shape[dim] // num_shards
    # torch.narrow slices along `dim`; plain indexing would always slice dim 0.
    return [
        torch.narrow(tensor, dim, i * slice_size, slice_size).clone().detach()
        for i in range(num_shards)
    ]

def process_key(key, tensors, num_shards, layer_kind):
    """Merge the input shards for a single parameter and re-shard it."""
    print(f'Processing {key}: in shapes = {[t.shape for t in tensors]}')
    # Norm weights (and rope.freqs, if present) are identical across shards, so replicate them.
    if key.replace('.weight', '').endswith('norm') or key == 'rope.freqs':
        return [tensors[0].clone().detach() for _ in range(num_shards)]
    for pattern, (kind, is_column) in layer_kind.items():
        if key.replace('.weight', '').endswith(pattern):
            # Column-parallel layers are split along dim 0, row-parallel layers along dim 1.
            merged_tensor = torch.cat(tensors, dim=0 if is_column else 1)
            return shard_tensor(merged_tensor, num_shards, 0 if is_column else 1)
    raise ValueError(f'Unrecognized parameter name: {key}')

def main(num_shards, input_model_dir, output_model_dir):
    params = load_parameters(input_model_dir)
    verify_shard_compatibility(params, num_shards)
    checkpoints = load_checkpoints(input_model_dir)
    output = [{} for _ in range(num_shards)]
    # Define sharding behavior for different layers: (parallel layer kind, split along dim 0?).
    layer_kind = {
        'tok_embeddings': ('ParallelEmbedding', False),
        'output': ('ColumnParallelLinear', True),
        'attention.wq': ('ColumnParallelLinear', True),
        'attention.wk': ('ColumnParallelLinear', True),
        'attention.wv': ('ColumnParallelLinear', True),
        'attention.wo': ('RowParallelLinear', False),
        'feed_forward.w1': ('ColumnParallelLinear', True),
        'feed_forward.w2': ('RowParallelLinear', False),
        'feed_forward.w3': ('ColumnParallelLinear', True),
    }
    for key in checkpoints[0]:
        tensors = [model[key] for model in checkpoints]
        output_tensors = process_key(key, tensors, num_shards, layer_kind)
        for rank, tensor in enumerate(output_tensors):
            output[rank][key] = tensor

    # Save the sharded models.
    os.makedirs(output_model_dir, exist_ok=True)
    with open(os.path.join(output_model_dir, 'params.json'), 'w') as file:
        json.dump(params, file)
    for rank, data in enumerate(output):
        torch.save(data, os.path.join(output_model_dir, f'consolidated.{rank:02d}.pth'))

if __name__ == "__main__":
    if len(sys.argv) != 4:
        print(f'Usage: {sys.argv[0]} <new-shards> <input-model-path> <output-model-path>', file=sys.stderr)
        sys.exit(1)
    main(int(sys.argv[1]), sys.argv[2], sys.argv[3])
    print('Sharding complete.')
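
A minimal usage sketch: the script name shard_llm.py, the ./llama-7b* paths, and the single-shard input layout below are illustrative assumptions, not part of the gist. Invoke the script with the new shard count and the input/output directories, then confirm that the output shards reassemble the original weight.

# Example invocation (shell):
#   python shard_llm.py 2 ./llama-7b ./llama-7b-2shard
#
# Quick sanity check that the output shards reassemble a weight from the
# (assumed single-shard) input checkpoint.
import glob
import torch

original = torch.load('./llama-7b/consolidated.00.pth', map_location='cpu')
shards = [torch.load(path, map_location='cpu')
          for path in sorted(glob.glob('./llama-7b-2shard/consolidated.*.pth'))]

key = 'layers.0.attention.wq.weight'  # a ColumnParallelLinear weight, split along dim 0
rebuilt = torch.cat([shard[key] for shard in shards], dim=0)
print(torch.equal(rebuilt, original[key]))  # expect True for a single-shard input model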