Awni Hannun (awni)

awni / MLX_0_17_3.pdf
Last active September 15, 2024 22:02
MLX Documentation PDF Versions
awni / resnet_mlx.py
Created September 7, 2024 20:02
MLX ResNet18 Inference Benchmark
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
import time
class Block(nn.Module):
    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
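The preview cuts off inside Block.__init__. A minimal sketch of how such a residual basic block is typically written in MLX (the layer names and structure here are assumptions, not the gist's exact code):

import mlx.core as mx
import mlx.nn as nn

class Block(nn.Module):
    # Sketch of an assumed ResNet basic block, not the gist's exact code.
    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(dims)
        self.conv2 = nn.Conv2d(dims, dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm(dims)
        # Project the input when the shape changes so the residual add is valid
        if stride != 1 or in_dims != dims:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm(dims),
            )
        else:
            self.shortcut = lambda x: x

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return nn.relu(out + self.shortcut(x))
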
awni / fast_conway_mlx.py
Last active September 1, 2024 08:37
Conway's Game of Life Accelerated with Custom Kernels in MLX
import av
import numpy as np
import mlx.core as mx
def conway(a: mx.array):
    source = """
        uint i = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint n = threads_per_grid.x;
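The preview stops inside the Metal source. A minimal sketch of how such a source string is compiled and launched with mx.fast.metal_kernel (the kernel body below is a placeholder, not the gist's real update rule, and the grid sizes are assumptions):

import mlx.core as mx

def conway_step(a: mx.array) -> mx.array:
    # Placeholder body: copies input to output instead of the real Conway rule
    source = """
        uint i = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint n = threads_per_grid.x;
        out[i * n + j] = a[i * n + j];
    """
    kernel = mx.fast.metal_kernel(
        name="conway",
        input_names=["a"],
        output_names=["out"],
        source=source,
    )
    # Launch one thread per cell; threadgroup size is an arbitrary choice
    (out,) = kernel(
        inputs=[a],
        grid=(a.shape[0], a.shape[1], 1),
        threadgroup=(16, 16, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return out
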
awni / mlx_api_prompt.py
Created August 20, 2024 15:43
Meta Llama 3.1 with MLX LM and the MLX Python API as Context
import os
import mlx.core as mx
from mlx_lm import load, generate
filename = os.path.join(os.path.dirname(mx.__file__), "core/__init__.pyi")
with open(filename, 'r') as fid:
    prompt = fid.read()
prompt += "\nHow do you write a self-attention layer using the above API in MLX?"
model, tokenizer = load("mlx-community/meta-Llama-3.1-8B-Instruct-4bit")
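The preview ends before the generation call. A minimal sketch of the final step, assuming mlx_lm's generate API (the max_tokens budget is an arbitrary choice):

# Sketch of the final step, assuming mlx_lm's generate API
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=512,  # arbitrary budget for the answer
    verbose=True,    # stream tokens to stdout as they are produced
)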

Set up the repo

git clone git@github.com:filipstrand/mflux.git
cd mflux && pip install -r requirements.txt

Make a run script

Name the file anything you like, e.g. flux.py. Make sure to update the two paths marked below.

awni / l3min.py
Last active August 23, 2024 22:35
A minimal, fast implementation of Llama 3.1 in MLX.
"""
A minimal, fast example generating text with Llama 3.1 in MLX.
To run, install the requirements:
pip install -U mlx transformers fire
Then generate text with:
python l3min.py "How tall is K2?"
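The script itself is the gist; independent of it, the core of a minimal greedy decode loop in MLX looks roughly like this sketch (no KV cache shown, and the model interface is an assumption):

import mlx.core as mx

def greedy_generate(model, prompt_tokens, max_tokens):
    # Assumes `model` maps a (1, T) token array to (1, T, vocab) logits
    tokens = mx.array(prompt_tokens)[None]
    for _ in range(max_tokens):
        logits = model(tokens)
        next_token = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(next_token)  # force the computation one token at a time
        tokens = mx.concatenate([tokens, next_token[None]], axis=1)
    return tokens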
awni / metal_in_python.py
Last active August 12, 2024 20:56
Compile and call a Metal GPU kernel from Python
# Requires:
# pip install pyobjc-framework-Metal
import numpy as np
import Metal
# Get the default GPU device
device = Metal.MTLCreateSystemDefaultDevice()
# Make a command queue to encode command buffers to
command_queue = device.newCommandQueue()
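The preview stops after the command queue. A minimal sketch of the remaining steps, assuming a kernel function named "my_kernel" inside a Metal source string `source` (the names, buffer contents, and dispatch sizes are illustrative; selector names follow PyObjC's colon-to-underscore rule):

# Compile the Metal source and look up the kernel (hypothetical name "my_kernel")
library, error = device.newLibraryWithSource_options_error_(source, None, None)
kernel_fn = library.newFunctionWithName_("my_kernel")
pipeline, error = device.newComputePipelineStateWithFunction_error_(kernel_fn, None)

# Copy input data into a GPU-visible buffer
data = np.zeros(1024, dtype=np.float32)
input_buffer = device.newBufferWithBytes_length_options_(
    data.tobytes(), data.nbytes, Metal.MTLResourceStorageModeShared
)

# Encode the dispatch, run it, and wait for completion
command_buffer = command_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(pipeline)
encoder.setBuffer_offset_atIndex_(input_buffer, 0, 0)
encoder.dispatchThreads_threadsPerThreadgroup_(
    Metal.MTLSizeMake(1024, 1, 1), Metal.MTLSizeMake(32, 1, 1)
)
encoder.endEncoding()
command_buffer.commit()
command_buffer.waitUntilCompleted()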

Avoid Overly Frequent Graph Evaluations

MLX is lazy: no actual computation happens until you explicitly or implicitly evaluate the graph. Here are some of the ways that can happen (a minimal sketch follows the list):

  • An explicit call to mx.eval
  • Calling a.item() on a scalar array
  • Converting an array to NumPy, e.g. np.array(a)
  • Printing an array
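A minimal sketch of the difference between building the graph and forcing these evaluations:

import mlx.core as mx

a = mx.random.normal((1024, 1024))
b = a @ a  # lazy: only records the computation, nothing runs yet

mx.eval(b)          # explicit evaluation: the matmul actually executes here
s = b.sum().item()  # .item() also forces evaluation (implicit)
print(b)            # printing evaluates too
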
from typing import Callable, Tuple
import operator
from functools import reduce
from itertools import product
import mlx.core as mx
def _interpolate(
    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
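The helper's body is cut off; independent of it, a minimal sketch of nearest-neighbor upsampling in MLX shows the indexing idea (the function name and (H, W, C) layout are assumptions):

import mlx.core as mx

def upsample_nearest(x: mx.array, scale: int) -> mx.array:
    # x has shape (H, W, C); repeat each row and column `scale` times
    H, W, _ = x.shape
    rows = mx.arange(H * scale) // scale  # source row for each output row
    cols = mx.arange(W * scale) // scale  # source column for each output column
    return x[rows][:, cols]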

MLX LM with the OpenAI Python Package

1. Install

Install MLX LM and openai:

pip install mlx-lm openai
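After installing, the usual pattern is to start the MLX LM server, which exposes an OpenAI-compatible HTTP API, and point the openai client at it. A minimal sketch, reusing the model name from above (the port is mlx-lm's default; treat the exact flags as assumptions):

mlx_lm.server --model mlx-community/meta-Llama-3.1-8B-Instruct-4bit

Then query it from Python:

from openai import OpenAI

# The local server ignores the API key, but the client requires one
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/meta-Llama-3.1-8B-Instruct-4bit",
    messages=[{"role": "user", "content": "How tall is K2?"}],
)
print(response.choices[0].message.content)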