Skip to content

Instantly share code, notes, and snippets.

View kouroshHakha's full-sized avatar

kourosh hakhamaneshi kouroshHakha

View GitHub Profile
@kouroshHakha
kouroshHakha / create_test_dataset.py
Created December 11, 2023 05:06
JSON Mode and Function-calling on Open LLMs Blogpost
import datasets
import re
import json
import tqdm
ds = datasets.load_dataset("glaiveai/glaive-function-calling-v2", split="train")
out_ds_size = 100
class UserAssistantNotFoundError(Exception):
@kouroshHakha
kouroshHakha / fp16_vs_bf16_model_loading.py
Created October 20, 2023 17:31
Studies the diff on precision when loading in fp16 or bf16
from safetensors import safe_open
import torch
import numpy as np
import matplotlib.pyplot as plt
tensors = {}
model_ckpt = "/home/ray/default/7b-chat-lora-ckpt/adapter_model.safetensors"
with safe_open(model_ckpt, framework="pt") as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)
import torch
import torch.nn.functional as F
# Create random inputs for testing
batch_size = 128
seq_length = 512
embed_dim = 64
enable_math = False
query = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
@kouroshHakha
kouroshHakha / bm_attn.py
Created August 22, 2023 03:58
benchmark_flash
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
import pandas as pd
import os
from ray.train.huggingface import HuggingFacePredictor
import pandas as pd
import re
(ray) kourosh@kourosh-JRFKXJ33VL auto_prompting % python main.py
============== Trial 1 ===============
Current Prompt format:
I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show.
Failed Example:
Input:
pwd
Output:
import torch
import torch.nn.functional as F
import unittest
import xformers.ops as xops
import math
import time
MAX_ITER = 100
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pprint import pprint
import time
import gc
import matplotlib.pyplot as plt
import numpy as np
model_base = "gpt2"
#
# A fatal error has been detected by the Java Runtime Environment:
#
# SIGSEGV (0xb) at pc=0x00007f7a24d482ab, pid=38579, tid=0x00007f7a24b41340
#
# JRE version: OpenJDK Runtime Environment (Zulu 8.62.0.19-CA-linux64) (8.0_332-b09) (build 1.8.0_332-b09)
# Java VM: OpenJDK 64-Bit Server VM (25.332-b09 mixed mode linux-amd64 compressed oops)
# Problematic frame:
# C [libpthread.so.0+0x142ab] raise+0xcb
#
The one that works on CI machine
# https://console.anyscale-staging.com/o/anyscale-internal/configurations/app-config-details/bld_izp5d4ptbfkla4q1prbuxydpit
[INFO] 1/24/2023, 2:18:26 AM: * Step 19/28 (commit,modifyfs) : RUN pip install --upgrade --force-reinstall --no-cache-dir --use-deprecated=legacy-resolver "gymnasium[atari,mujoco]==0.26.3" "ale-py==0.8.0" "gym==0.26.2" "mujoco-py<2.2,>=2.1" "autorom[accept-rom-license]" && pip freeze (d6259223)
[INFO] 1/24/2023, 2:18:26 AM: Client TLS is disabled
[INFO] 1/24/2023, 2:18:26 AM: reg: localhost:5555 repo: anyscale/customer_image url: http://localhost:5555/v2/anyscale/customer_image/blobs/uploads/
[INFO] 1/24/2023, 2:18:26 AM: * Started pushing layer sha256:6d076eb12568fd23ca48a7a31640a2e2e40310f1981d2f1f026906ab687bc1f9
[INFO] 1/24/2023, 2:18:26 AM: Client TLS is disabled
[INFO] 1/24/2023, 2:18:27 AM: Collecting gymnasium[atari,mujoco]==0.26.3
[INFO] 1/24/2023, 2:18:27 AM: Downloading Gymnasium-0.26.3-py3-none-any.whl (836 kB)