@vuiseng9
Created August 5, 2024 21:34

Install

Follow the install instructions at https://github.com/state-spaces/mamba.
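
A minimal install sketch, assuming a CUDA-enabled PyTorch environment (see the repo's README for exact version requirements):

pip install causal-conv1d   # optional, fused causal Conv1d kernels
pip install mamba-ssm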

Run

import os
from collections import OrderedDict, defaultdict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel


def annotate_module_static_attr(top_module, family_name=None):
    # Annotate every module with static attributes:
    # first_name, last_name, class_name, is_leaf_module, leaf_has_weight, full_name
    if family_name is None:
        family = top_module.__class__.__name__.lower() + "class_as_family_name"
    else:
        family = family_name

    for parent_name, parent_module in top_module.named_modules():
        # Handle the top-level module explicitly: the children loop below only
        # names modules one level down, so the top module itself would be missed.
        if parent_name == "":
            parent_module.first_name = family
            parent_module.last_name = ""

        for child_name, child_module in parent_module.named_children():
            child_module.first_name = child_name
            if parent_name == "":
                # avoid a stray leading period for direct children of the top module
                child_module.last_name = f"{family}"
            else:
                child_module.last_name = f"{family}.{parent_name}"

        # The following applies to every module
        parent_module.is_leaf_module = False
        if len(list(parent_module.children())) == 0:
            parent_module.is_leaf_module = True
            parent_module.leaf_has_weight = len(list(parent_module.parameters())) > 0

        parent_module.class_name = parent_module.__class__.__name__
        parent_module.full_name = f"{parent_module.last_name}.{parent_module.first_name}"  # must be set last
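
# Toy usage sketch (hypothetical example, not part of the dump flow below):
# after the annotation pass, every submodule carries a dotted full_name.
#   toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#   annotate_module_static_attr(toy, family_name="toy")
#   print(toy[0].full_name)       # -> "toy.0"
#   print(toy[0].is_leaf_module)  # -> True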

model_id="microsoft/Phi-3-mini-4k-instruct"
model_id="state-spaces/mamba-2.8b"
model_id="state-spaces/mamba2-2.7b"
# model_id="mistralai/Mamba-Codestral-7B-v0.1"

maxlen = 10
device = "cuda"
dtype = torch.float16

if model_id.startswith("mistralai/Mamba"):
    from huggingface_hub import snapshot_download
    from pathlib import Path

    mistral_models_path = Path.home().joinpath('mistral_models', 'Mamba-Codestral-7B-v0.1')
    if not mistral_models_path.exists():
        mistral_models_path.mkdir(parents=True, exist_ok=True)
        snapshot_download(repo_id="mistralai/Mamba-Codestral-7B-v0.1", allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], local_dir=mistral_models_path)
    exit()
    #TODO not working for "mistralai/Mamba-Codestral-7B-v0.1", doesnt work natively with HF transformers


is_mamba = (
    model_id.startswith("state-spaces/mamba")
    or model_id.startswith("state-spaces/transformerpp")
    or model_id.startswith("mistralai/Mamba")
)

if is_mamba:
    # state-spaces checkpoints ship without a tokenizer; they use GPT-NeoX's
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(model_id, device=device, dtype=dtype)
else:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=dtype)

top_module = model
annotate_module_static_attr(top_module=top_module, family_name=os.path.basename(model_id))

modtype_to_modlist = defaultdict(list)  # class name -> list of module full names
modname_to_modtype = OrderedDict()      # module full name -> class name
modname_to_module = OrderedDict()       # module full name -> module object

for n, m in top_module.named_modules():
    modtype_to_modlist[m.class_name].append(m.full_name)
    modname_to_modtype[m.full_name] = m.class_name
    modname_to_module[m.full_name] = m
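
# Optional sanity check (sketch, not in the original gist): print a census of
# the module classes discovered by the annotation pass.
# for cls_name, mod_names in modtype_to_modlist.items():
#     print(f"{cls_name:40s} x{len(mod_names)}")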


# Record input/weight/output shapes for every nn.Linear on each forward call
layer_dump = defaultdict(list)

def hook(module, input, output):
    layer_dump[module.full_name].append(
        dict(
            ifm=tuple(input[0].shape),       # input feature map shape
            wei=tuple(module.weight.shape),  # weight shape
            ofm=tuple(output.shape),         # output feature map shape
        )
    )

hooks = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        hooks.append(module.register_forward_hook(hook))

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)

if is_mamba:
    out = model.generate(
        input_ids=input_ids,
        max_length=maxlen,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.7,
        top_k=1,
        top_p=0.9,
        min_p=0.0,
        repetition_penalty=1.2,
    )
    print(tokenizer.batch_decode(out.sequences.tolist()))
else:
    out = model.generate(input_ids, max_new_tokens=maxlen)
    print(tokenizer.batch_decode(out))
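
# Detach the shape-recording hooks now that generation is done (cleanup sketch,
# not in the original gist), so any further forward passes stop appending entries.
for h in hooks:
    h.remove()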

with open(f"layerwise_dump_{os.path.basename(model_id)}.csv", "w") as csvout:
    step=1
    for step in [0, 1]:
        for lid, l in enumerate(layer_dump):
            d = layer_dump[l][step]
            rpt_str = f"l{lid};{l};{step};i:{d['ifm']};w:{d['wei']};o:{d['ofm']}"

            print(rpt_str)
            csvout.write(rpt_str+"\n")
print("end.")