Shard large LLM checkpoints
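The script below re-shards a Llama-style checkpoint: it loads every *.pth file in the input directory, concatenates each weight back to its full shape along its tensor-parallel dimension, and re-splits it into the requested number of shards. Column-parallel weights are split along dim 0; row-parallel weights and the token embedding are split along dim 1.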
"""Reshard a Llama-style model checkpoint into a different number of shards."""
import glob
import json
import os
import sys

import torch

def load_parameters(directory):
    """Load model parameters from a JSON file."""
    with open(os.path.join(directory, 'params.json'), 'r') as file:
        return json.load(file)
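
# For reference (illustrative values, not from the gist): a Llama-style
# params.json looks roughly like
#   {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "vocab_size": -1}
# Only 'dim' is consulted below, to check divisibility by the shard count.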

def verify_shard_compatibility(parameters, num_shards):
    """Ensure the parameter dimension is divisible by the number of shards."""
    if parameters['dim'] % num_shards != 0:
        raise ValueError(f"Number of shards must divide the parameter dimension {parameters['dim']}")

def load_checkpoints(directory):
    """Load all model checkpoints into CPU memory, in rank order."""
    pattern = os.path.join(directory, '*.pth')
    # Sort so consolidated.00.pth, consolidated.01.pth, ... line up with ranks.
    return [torch.load(path, map_location='cpu') for path in sorted(glob.glob(pattern))]

def shard_tensor(tensor, num_shards, dim):
    """Shard a tensor into equal slices along the specified dimension."""
    slice_size = tensor.shape[dim] // num_shards
    # narrow() slices along `dim`; plain [start:stop] indexing would only ever slice dim 0.
    return [tensor.narrow(dim, i * slice_size, slice_size).clone().detach() for i in range(num_shards)]
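
# Shape sketch (hypothetical sizes): resharding a 2-way checkpoint to 4 ways,
# a ColumnParallelLinear weight stored as two (2048, 4096) slices is merged to
# (4096, 4096) along dim 0, then shard_tensor(..., 4, 0) yields four
# (1024, 4096) slices; a RowParallelLinear weight merges and re-splits along
# dim 1 instead.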

def process_key(key, tensors, num_shards, layer_kind):
    """Merge a parameter's per-rank slices and re-shard it."""
    print(f'Processing {key}: in shapes = {[t.shape for t in tensors]}')
    for pattern, (kind, is_column) in layer_kind.items():
        if key.replace('.weight', '').endswith(pattern):
            # Column-parallel layers are split along dim 0, row-parallel
            # layers (and the embedding) along dim 1.
            dim = 0 if is_column else 1
            merged_tensor = torch.cat(tensors, dim=dim)
            return shard_tensor(merged_tensor, num_shards, dim)
    # Norm weights and rope frequencies are replicated on every rank, not sharded.
    if 'norm' in key or key.endswith('rope.freqs'):
        return [tensors[0].clone().detach() for _ in range(num_shards)]
    raise ValueError(f'Unrecognized parameter name: {key}')

def main(num_shards, input_model_dir, output_model_dir):
    params = load_parameters(input_model_dir)
    verify_shard_compatibility(params, num_shards)
    checkpoints = load_checkpoints(input_model_dir)
    # One output state dict per target shard.
    output = [{} for _ in range(num_shards)]
    layer_kind = {
        # (fairscale parallel-layer class, sharded along dim 0?) per weight pattern.
        'tok_embeddings': ('ParallelEmbedding', False),
        'output': ('ColumnParallelLinear', True),
        'attention.wq': ('ColumnParallelLinear', True),
        'attention.wk': ('ColumnParallelLinear', True),
        'attention.wv': ('ColumnParallelLinear', True),
        'attention.wo': ('RowParallelLinear', False),
        'feed_forward.w1': ('ColumnParallelLinear', True),
        'feed_forward.w2': ('RowParallelLinear', False),
        'feed_forward.w3': ('ColumnParallelLinear', True),
    }
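    # Megatron-style pairing: the column-parallel projections (wq/wk/wv, w1/w3)
    # produce activations that stay split across ranks, and the following
    # row-parallel projection (wo, w2) consumes them and all-reduces the result.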
    for key in checkpoints[0]:
        tensors = [model[key] for model in checkpoints]
        output_tensors = process_key(key, tensors, num_shards, layer_kind)
        for rank, tensor in enumerate(output_tensors):
            output[rank][key] = tensor

    # Save the sharded models alongside an unchanged params.json.
    os.makedirs(output_model_dir, exist_ok=True)
    with open(os.path.join(output_model_dir, 'params.json'), 'w') as file:
        json.dump(params, file)
    for rank, data in enumerate(output):
        torch.save(data, os.path.join(output_model_dir, f'consolidated.{rank:02d}.pth'))

if __name__ == "__main__":
    if len(sys.argv) != 4:
        print(f'Usage: {sys.argv[0]} <new-shards> <input-model-path> <output-model-path>', file=sys.stderr)
        sys.exit(1)
    main(int(sys.argv[1]), sys.argv[2], sys.argv[3])
    print('Sharding complete.')
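
Usage sketch (the script name and paths are placeholders, not part of the gist): given a 2-shard checkpoint in ./llama-13b containing params.json plus consolidated.00.pth and consolidated.01.pth, re-shard it for four GPUs with

    python reshard.py 4 ./llama-13b ./llama-13b-4shard

The output directory will then hold consolidated.00.pth through consolidated.03.pth plus an unmodified copy of params.json.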