Experiment planning with NVIDIA
"""Decoder-only LM scaling experiments on GPUs.""" | |
from jax import numpy as jnp | |
from paxml import experiment_registry | |
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd | |
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmdPipeline | |
from praxis import layers | |
# TODO(zhangqiaorjc): Might need to use pmap instead of pjit for smaller models. | |
# TODO(zhangqiaorjc): Configure CHECKPOINT_POLICY for all experiments. | |


@experiment_registry.register
class NvidiaScaling1B(LmCloudSpmd):
  """Model with 1.3B params.

  Global batch size = 4 * 16 * 8 = 512.
  This config works on 16 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 32
  DIMS_PER_HEAD = 64
  MODEL_DIMS = 2048
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 24
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 128-way data parallelism.
  ICI_MESH_SHAPE = [128, 1, 1]
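

# A minimal sanity-check sketch (added here for illustration, not part of the
# original experiments): the parameter counts claimed in the docstrings can be
# approximated for a standard GPT-style decoder as ~12 * L * d^2 weights in the
# transformer blocks plus a vocab * d embedding table (biases and layer norms
# are ignored).
def _approx_param_count_billions(cfg) -> float:
  """Rough decoder-only parameter count, in billions."""
  transformer = 12 * cfg.NUM_LAYERS * cfg.MODEL_DIMS**2
  embedding = cfg.VOCAB_SIZE * cfg.MODEL_DIMS
  return (transformer + embedding) / 1e9


# _approx_param_count_billions(NvidiaScaling1B) ~= 1.31, matching "1.3B" above.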


@experiment_registry.register
class NvidiaScaling5B(LmCloudSpmd):
  """Model with 5B params.

  Global batch size = 8 * 20 * 8 = 1280.
  This config works on 20 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 32
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 4096
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 24
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 80-way data parallelism, 2-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 2]
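

# A quick consistency check (an added sketch, not part of the original file):
# the product of the ICI mesh axes must equal the total GPU count, and each
# docstring's global batch size is PERCORE_BATCH_SIZE times that count.
assert 80 * 1 * 2 == 20 * 8  # NvidiaScaling5B's mesh covers all 160 GPUs.
assert NvidiaScaling5B.PERCORE_BATCH_SIZE * 20 * 8 == 1280  # Global batch size.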


@experiment_registry.register
class NvidiaScaling8B(LmCloudSpmd):
  """Model with 8.3B params.

  Global batch size = 4 * 16 * 8 = 512.
  This config works on 16 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 64
  DIMS_PER_HEAD = 64
  MODEL_DIMS = 4096
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 40
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 32-way data parallelism, 4-way tensor parallelism.
  ICI_MESH_SHAPE = [32, 1, 4]


@experiment_registry.register
class NvidiaScaling10B(LmCloudSpmd):
  """Model with 10B params.

  Global batch size = 2.25 * 80 * 8 = 1440.
  This config works on 80 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 2.25
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 40
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 5120
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 32
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 80-way data parallelism, 8-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 8]


@experiment_registry.register
class NvidiaScaling20B(LmCloudSpmd):
  """Model with 20B params.

  Global batch size = 2.25 * 80 * 8 = 1440.
  This config works on 80 hosts * 8 A100s.
  """
  PERCORE_BATCH_SIZE = 2.25
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 48
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 6144
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 44
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 80-way data parallelism, 8-way tensor parallelism.
  ICI_MESH_SHAPE = [80, 1, 8]


@experiment_registry.register
class NvidiaScaling40B(LmCloudSpmdPipeline):
  """Model with 40B params.

  Global batch size = 2.25 * 80 * 8 = 1440.
  This config works on 80 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 2
  PERCORE_BATCH_SIZE = 2.25
  NUM_STAGES = 4
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 48
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 6144
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 44
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 20-way data, 4-way pipeline, and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 20, 1, 8]
  DCN_MESH_SHAPE = [4, 1, 1, 1]
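

# How the two meshes compose (an added explanatory sketch, not part of the
# original file): ICI_MESH_SHAPE partitions over the fast local interconnect
# and DCN_MESH_SHAPE over the slower data-center network, so the total device
# count is the product of all eight axes. The pipeline stages sit on the DCN
# axis, so only the comparatively small stage boundary activations cross the
# slow links.
assert (1 * 20 * 1 * 8) * (4 * 1 * 1 * 1) == 80 * 8  # 640 GPUs in total.
assert NvidiaScaling40B.NUM_STAGES == 4  # One stage per DCN pipeline slice.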


@experiment_registry.register
class NvidiaScaling116B(LmCloudSpmdPipeline):
  """Model with 116B params.

  Global batch size = 1.5 * 128 * 8 = 1536.
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 2
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 64
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 16-way data, 8-way pipeline, and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]


@experiment_registry.register
class NvidiaScaling175B(LmCloudSpmdPipeline):
  """Model with 175B params.

  Global batch size = 1.5 * 128 * 8 = 1536.
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 1
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 96
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 16-way data, 8-way pipeline, and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]


@experiment_registry.register
class GoogleScaling175B(LmCloudSpmdPipeline):
  """Model with 175B params.

  Global batch size = 1.5 * 128 * 8 = 1536.
  This config works on 128 hosts * 8 A100s.
  """
  MICROBATCH_SIZE = 1
  PERCORE_BATCH_SIZE = 1.5
  NUM_STAGES = 8
  VOCAB_SIZE = 51200
  MAX_SEQ_LEN = 2048
  NUM_HEADS = 96
  DIMS_PER_HEAD = 128
  MODEL_DIMS = 12288
  HIDDEN_DIMS = MODEL_DIMS * 4
  NUM_LAYERS = 96
  FPROP_DTYPE = jnp.float32
  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING

  # 16-way data, 8-way pipeline, and 8-way model parallelism.
  ICI_MESH_SHAPE = [1, 16, 1, 8]
  DCN_MESH_SHAPE = [8, 1, 1, 1]
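

# Example launch command (a hedged sketch; the --exp path depends on where this
# module lives in your source tree, shown here as a hypothetical placeholder):
#
#   python3 -m paxml.main \
#     --exp=<this_module_path>.NvidiaScaling1B \
#     --job_log_dir=<output_dir>
#
# --exp takes the registered experiment's import path; every class above is
# registered via @experiment_registry.register.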