Skip to content

Instantly share code, notes, and snippets.

View MekkCyber's full-sized avatar

Mohamed Mekkouri MekkCyber

View GitHub Profile
@MekkCyber
MekkCyber / bitblas_linear.py
Created September 1, 2024 17:31
BitBlas Linear Layer
import bitblas
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import auto_detect_nvidia_target
# Module-level configuration, evaluated once when this module is imported.
# BITBLAS_TARGET holds whatever bitblas.auto_detect_nvidia_target() returns
# (presumably a target identifier for the local NVIDIA GPU -- TODO confirm).
BITBLAS_TARGET = auto_detect_nvidia_target()
# BITBLAS_DATABASE_PATH holds whatever bitblas.cache.get_database_path() returns
# (presumably the location of the operator cache/database -- TODO confirm).
BITBLAS_DATABASE_PATH = get_database_path()
import torch
import torch.nn.functional as F
import torch.nn as nn
# Adapted from https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py
class BitLinear158(nn.Module):
@MekkCyber
MekkCyber / kernel.py
Created August 29, 2024 14:24
Kernel for matmul while unpacking int2 weights
import torch
import triton
import triton.language as tl
def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
values_per_item = 8 // bits
packed_shape = packed.shape
if len(packed_shape) == 1: