# Grr. Can't really use jupyter-as-script convention: https://twitter.com/uogbuji/status/1829325187965170154
# %%
'''
test/llm_struct_probe.py
'''
import asyncio
class llm_struct_probe:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = Model()
        self.model.load(model_path)
        self.model_type = self.model.model.model_type

    async def test(self):
        async for chunk in self.test1():
            print(chunk, end='')

    async def test1(self):
        sysprompt = (
            'You are a helpful assistant with access to a set of tools which you may '
            "invoke to help respond to the user's request.\n"
            "You may also choose not to use any of the tools, if you're sure they're not "
            'useful for this response. In that case, you can fill out the\n'
            '`toolio_none` pattern for your response\n'
            '\n'
            '\n'
            'Tool name: today_kfabe\n'
            ' Get the current date\n'
            'Invocation schema:\n'
            '{"type": "object", "properties": {"name": {"type": "const", "const": '
            '"today_kfabe"}, "arguments": {"type": "object", "properties": {}, '
            '"required": []}}, "required": ["name", "arguments"]}\n'
            '\n'
            'Tool name: toolio_none\n'
            ' Call this tool to indicate that no other provided tool is useful for '
            'responding to the user\n'
            'Invocation schema:\n'
            '{"type": "object", "properties": {"name": {"type": "const", "const": '
            '"toolio_none"}, "arguments": {"type": "object", "properties": {"response": '
            '{"type": "string", "description": "Your normal response to the user"}}}}, '
            '"required": ["name", "arguments"]}\n'
            'Your answer is a JSON array with one or more tool invocations according to '
            'the appropriate schema(s),\n'
            'or it follows the `toolio_none` pattern, as appropriate to respond to the '
            "user's prompt below.\n")
        prompt = 'Write me a haiku about AI'
        tool_schemas = [
            {'properties': {'arguments': {'properties': {}, 'required': [], 'type': 'object'},
                            'name': {'const': 'today_kfabe', 'type': 'const'}},
             'required': ['name', 'arguments'],
             'type': 'object'},
            {'properties': {'arguments': {'properties': {'response': {'description': 'Your normal response to the user',
                                                                      'type': 'string'}},
                                          'type': 'object'},
                            'name': {'const': 'toolio_none', 'type': 'const'}},
             'required': ['name', 'arguments'],
             'type': 'object'}]
        full_schema = {'type': 'array', 'items': {'anyOf': tool_schemas}}
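        # Illustration (added comment, not in the original gist): a completion conforming to
        # full_schema is a JSON array of tool invocations, e.g.
        # [{"name": "toolio_none", "arguments": {"response": "..."}}]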
        messages = [{'role': 'system', 'content': sysprompt}, {'role': 'user', 'content': prompt}]
        responder = ToolCallResponder(self.model_path, self.model_type)
        prompt_tokens = None
        for result in self.model.completion(messages, full_schema, max_tokens=1024, temp=0.0, cache_prompt=False):
            if result['op'] == 'evaluatedPrompt':
                prompt_tokens = result['token_count']
            elif result['op'] == 'generatedTokens':
                message = responder.generated_tokens(result['text'])
                if message:
                    yield message
            elif result['op'] == 'stop':
                completion_tokens = result['token_count']
                yield responder.generation_stopped(
                    result['reason'], prompt_tokens, completion_tokens
                )
            else:
                raise RuntimeError(f'Unknown result operation {result["op"]}')

# Below is just most of https://github.com/otriscon/llm-structured-output/blob/main/src/examples/llm_schema.py
"""
Example of JSON schema decoding with MLX.
"""
import argparse
import json
import time
from math import inf
from operator import itemgetter
from typing import Iterable, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load
from llm_structured_output import JsonSchemaAcceptorDriver
from llm_structured_output.util.bitmap import (
    bias_logits,
    count_set_bits,
    enumerate_set_bits,
)
from llm_structured_output.util.output import info, bold, bolddim, debug
from llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper
class RejectedCompletion(Exception):
"""
It's rare, but sometimes we reach a state where it's not possible to
advance the acceptor. For example, when closing a JSON string we get
a higher probability for slanted quotes than straight ones and select
the wrong token. At that point, the LLM will continue generating with
the prior that the string is closed, but our acceptor will remain in
the string-accepting state. This can indicate an issue with the
tokenizer vocabulary passed to the acceptor, or a bug in the code
used to decode tokens from the LLM. If none of these apply, check that
the LLM is actually able to generate JSON, although most are.
"""
class Model:
    def __init__(self):
        mx.random.seed(0)
        self.model = None
        self.tokenizer = None
        self.vocabulary = None
        self.eos_id = None
        self.json_schema_acceptor_driver_factory = None
        self._cached_prompt = None
        self._cached_cache = None

    def load(self, model_path: str):
        """
        Load locally or download from Huggingface hub.
        """
        self.model, tokenizer = load(model_path)
        self.tokenizer = HuggingfaceTokenizerHelper(tokenizer)
        self.vocabulary, self.eos_id = self.tokenizer.extract_vocabulary()
        self.json_schema_acceptor_driver_factory = (
            JsonSchemaAcceptorDriver.driver_factory_for_model(
                self.vocabulary, self.eos_id
            )
        )

    def get_driver_for_json_schema(self, schema, encapsulated: bool = False):
        return self.json_schema_acceptor_driver_factory(
            schema, is_encapsulated_json=encapsulated
        )

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # We need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
        logits = self.model(mx.array(tokens)[None], cache)
        return logits, cache

    def _decode(self, tokens):
        return self.tokenizer.no_strip_decode(tokens)

    def _debug_top_tokens(self, logits, count=10):
        token_logits = sorted(
            enumerate(logits.tolist()), key=itemgetter(1), reverse=True
        )
        top_tokens = [
            (self._decode([t]), p) for t, p in token_logits[:count] if p != -inf
        ]
        debug("TOP TOKENS:", top_tokens)

    def _sample(self, logits, temp: float = 0):
        if temp == 0:
            result = mx.argmax(logits, axis=-1)
        else:
            result = mx.random.categorical(logits * (1 / temp))
        return result.item()
    def _sample_with_bias(
        self, logits, temp: float = 0, token_acceptor=None, lazy_bias: bool = True
    ):
        if token_acceptor is None:
            return self._sample(logits, temp)
        if lazy_bias:
            # Optimistically sample without masking; only fall back to computing the
            # valid-token bitmap if the acceptor rejects the sampled token.
            token = self._sample(logits, temp)
            try:
                token_acceptor.advance_token(token)
                return token
            except JsonSchemaAcceptorDriver.TokenRejected:
                pass
        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        if not accepted_token_bitmap:
            debug(token_acceptor.cursors)
            self._debug_top_tokens(logits)
            raise RejectedCompletion()
        token = self._sample(bias_logits(mx, logits, accepted_token_bitmap), temp)
        token_acceptor.advance_token(token)
        return token
    def generate_without_schema(self, logits, cache, temp: Optional[float] = 0.0):
        """
        For testing / comparison purposes.
        """
        while True:
            tokens = [self._sample(logits[0, -1, :], temp)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache)

    def generate_with_schema(
        self, logits, cache, token_acceptor, temp: Optional[float] = 0.0
    ):
        while True:
            tokens = [self._sample_with_bias(logits[0, -1, :], temp, token_acceptor)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache)

    def generate_with_preemptive_decoding(
        self,
        logits,
        cache,
        token_acceptor,
        temp: Optional[float] = 0.0,
        max_batch_size=5,
    ):
        """
        Try to generate faster by precomputing two tokens at a time when possible.
        If we know that the acceptor will only accept a small set of tokens after
        the current one, we can evaluate a batch with one entry per possible
        future token. Each entry in the batch contains the current token sampled,
        which we have to evaluate anyway, and a second token corresponding to one
        of the possible tokens that could be sampled from the output to the first
        token. We get back logits for both tokens for each item in the batch: the
        logits for the first token will be the same (as long as the model applies
        a causal mask), and we can sample those logits to select from which of the
        items in the batch we can select the second token.
        In practice, this only seems to accelerate things for unquantized models.
        """
        # Sample token from prompt evaluation
        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        first_token_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap)
        first_token = self._sample(first_token_logits, temp)
        tokens = [first_token]
        yield tokens
        token_acceptor.advance_token(first_token)
        accepted_token_bitmap = token_acceptor.select_valid_tokens()

        while True:
            last_token = tokens[-1]
            if count_set_bits(accepted_token_bitmap) in range(1, max_batch_size + 1):
                # If the number of possible follow-up tokens is small, submit for
                # evaluation a batch of 2-token continuations.
                batch = []
                for followup_token in enumerate_set_bits(accepted_token_bitmap):
                    batch.append([last_token, followup_token])
                # Re-shape the cache to match the input.
                for layer_cache in cache:
                    layer_cache.keys = mx.concatenate([layer_cache.keys] * len(batch))
                    layer_cache.values = mx.concatenate(
                        [layer_cache.values] * len(batch)
                    )
            else:  # Otherwise, submit the normal one-token continuation.
                batch = [[last_token]]

            logits = self.model(mx.array(batch), cache)
            mx.eval(logits)

            first_token_logits = bias_logits(mx, logits[0, 0, :], accepted_token_bitmap)
            first_token = self._sample(first_token_logits, temp)
            tokens = [first_token]
            if first_token == self.eos_id:
                yield tokens
                break
            token_acceptor.advance_token(first_token)
            accepted_token_bitmap = token_acceptor.select_valid_tokens()
            if not accepted_token_bitmap:
                raise RejectedCompletion()

            # If we had submitted 2-token continuations, we can decode a second token
            if len(batch[0]) > 1:
                index = next(  # Find which of the second tokens was selected
                    i
                    for i, batch_item in enumerate(batch)
                    if batch_item[1] == first_token
                )
                second_token_logits = bias_logits(
                    mx, logits[index, 1, :], accepted_token_bitmap
                )
                second_token = self._sample(second_token_logits, temp)
                tokens.append(second_token)
                token_acceptor.advance_token(second_token)
                accepted_token_bitmap = token_acceptor.select_valid_tokens()
                # Select the accepted generation in the cache, restoring it to batch dimension 1.
                for layer_cache in cache:
                    layer_cache.keys = layer_cache.keys.split([index, index + 1])[1]
                    layer_cache.values = layer_cache.values.split([index, index + 1])[1]

            yield tokens
    def _generate_tokens(
        self,
        generator: Iterable,
        max_tokens: int = 1000,
    ) -> Iterable:
        start_time = time.time_ns()
        token_count = 0
        for tokens in generator:
            token_count += len(tokens)

            try:
                eos_index = tokens.index(self.eos_id)
                tokens = tokens[0:eos_index]
            except ValueError:
                eos_index = -1

            if tokens:
                text = self._decode(tokens)
                yield {
                    "op": "generatedTokens",
                    "text": text,
                    "token_count": len(tokens),
                    "time_ms": (time.time_ns() - start_time) / 1e6,
                }

            if eos_index >= 0:
                yield {"op": "stop", "reason": "end"}
                return

            if token_count >= max_tokens:
                yield {"op": "stop", "reason": "max_tokens"}
                return

            start_time = time.time_ns()

        assert False
    def completion(
        self,
        prompt: Union[str, Iterable[dict[str, str]]],
        schema: dict,
        encapsulated: bool = False,
        max_tokens: int = 1000,
        temp: float = 0.0,
        seed: int = None,
        preemptive_batch_size: int = 0,
        cache_prompt: bool = False,
    ):
        if seed is not None:
            mx.random.seed(seed)

        start_time = time.time_ns()
        prompt_tokens = self.tokenizer.encode_prompt(prompt)
        logits, cache = self._evaluate_prompt(
            prompt_tokens, self._cached_prompt, self._cached_cache
        )
        if cache_prompt:
            self._cached_prompt = prompt_tokens
            self._cached_cache = cache
        # Eager eval to more accurately reflect the prompt evaluation time.
        mx.eval(logits)
        prompt_time = time.time_ns() - start_time
        yield {
            "op": "evaluatedPrompt",
            "prompt": prompt,
            "token_count": len(prompt_tokens),
            "time_ms": prompt_time / 1e6,
            "prompt_tps": len(prompt_tokens) / (prompt_time / 1e9),
        }

        if schema:
            token_acceptor = self.get_driver_for_json_schema(schema, encapsulated)
            if preemptive_batch_size > 0:
                generator = self.generate_with_preemptive_decoding(
                    logits,
                    cache,
                    token_acceptor,
                    temp,
                    max_batch_size=preemptive_batch_size,
                )
            else:
                generator = self.generate_with_schema(
                    logits, cache, token_acceptor, temp
                )
        else:
            generator = self.generate_without_schema(logits, cache, temp)

        token_count = 0
        generation_time = 0
        for generation_result in self._generate_tokens(generator, max_tokens):
            if generation_result["op"] == "generatedTokens":
                token_count += generation_result["token_count"]
                generation_time += generation_result["time_ms"]
            elif generation_result["op"] == "stop":
                generation_result["token_count"] = token_count
                generation_result["time_ms"] = generation_time
                # This is slightly incorrect, because the first token is generated
                # from the prompt evaluation.
                generation_result["generation_tps"] = token_count / (
                    generation_time / 1e3
                )
            yield generation_result

# Below is a big reduction from https://github.com/otriscon/llm-structured-output/blob/main/src/examples/server.py
class ChatCompletionResponder:
    def __init__(self, model_name: str):
        self.object_type = "chat.completion"
        self.model_name = model_name
        self.created = int(time.time())
        self.id = f"{id(self)}_{self.created}"
        self.content = ""

    def message_properties(self):
        return {
            "object": self.object_type,
            "id": f"chatcmpl-{self.id}",
            "created": self.created,
            "model": self.model_name,
        }

    def translate_reason(self, reason):
        """
        Translate our reason codes to OpenAI ones.
        """
        if reason == "end":
            return "stop"
        if reason == "max_tokens":
            return "length"
        return f"error: {reason}"  # Not a standard OpenAI API reason

    def format_usage(self, prompt_tokens: int, completion_tokens: int):
        return {
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": completion_tokens + prompt_tokens,
            },
        }

    def generated_tokens(
        self,
        text: str,
    ):
        self.content += text
        return None

    def generation_stopped(
        self,
        stop_reason: str,
        prompt_tokens: int,
        completion_tokens: int,
    ):
        finish_reason = self.translate_reason(stop_reason)
        message = {"role": "assistant", "content": self.content}
        return {
            "choices": [
                {"index": 0, "message": message, "finish_reason": finish_reason}
            ],
            **self.format_usage(prompt_tokens, completion_tokens),
            **self.message_properties(),
        }

class ToolCallResponder(ChatCompletionResponder):
    def __init__(self, model_name: str, functions: list[dict]):
        # `functions` is unused in this reduced version (the probe above passes the
        # model type here); it is kept to preserve the upstream signature.
        super().__init__(model_name)

    def translate_reason(self, reason):
        if reason == "end":
            return "tool_calls"
        return super().translate_reason(reason)

    def generation_stopped(
        self,
        stop_reason: str,
        prompt_tokens: int,
        completion_tokens: int,
    ):
        finish_reason = self.translate_reason(stop_reason)
        if finish_reason == "tool_calls":
            tool_calls = json.loads(self.content)
            if not isinstance(tool_calls, list):
                # len(functions) == 1 was special cased
                tool_calls = [tool_calls]
            message = {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": f"call_{self.id}_{i}",
                        "type": "function",
                        "function": {
                            "name": function_call["name"],
                            "arguments": json.dumps(function_call["arguments"]),
                        },
                    }
                    for i, function_call in enumerate(tool_calls)
                ],
            }
        elif finish_reason == "function_call":
            function_call = json.loads(self.content)
            message = {
                "role": "assistant",
                "function_call": {
                    "name": function_call["name"],
                    "arguments": json.dumps(function_call["arguments"]),
                },
            }
        else:
            message = None
        return {
            "choices": [
                {"index": 0, "message": message, "finish_reason": finish_reason}
            ],
            **self.format_usage(prompt_tokens, completion_tokens),
            **self.message_properties(),
        }

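# For reference (added comment, not in the upstream example): generation_stopped() above
# returns an OpenAI-style chat completion dict shaped like
# {"choices": [{"index": 0, "finish_reason": "tool_calls",
#               "message": {"role": "assistant", "tool_calls": [
#                   {"id": "call_..._0", "type": "function",
#                    "function": {"name": "toolio_none", "arguments": "{...}"}}]}}],
#  "usage": {...}, "object": "chat.completion", "id": "chatcmpl-...", ...}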
# Below from src/examples/reusable_kv_cache.py
from mlx_lm.models.base import KVCache
class ReusableKVCache(KVCache):
"""
Usability improvements over KVCache.
"""
@classmethod
def for_model(cls, model):
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
return [cls(model.head_dim, n) for n in kv_heads]
def reuse(self, new_prompt_length, common_prefix_length):
"""
Reuse (part of) this cache for a new prompt that shares a prefix with it.
"""
if self.keys is None:
return
# Clip the cache to the common length.
self.offset = common_prefix_length
# Make sure the cache can fit the whole prompt. Because the offset is
# (very likely) not a multiple of the step size, update_and_fetch()
# won't resize the cache when evaluating the rest of the prompt as it
# would if it were an empty cache.
current_size = self.keys.shape[2]
if current_size < new_prompt_length:
n_steps = (self.step + new_prompt_length - 1) // self.step
k_add_shape = (1, self.n_kv_heads, n_steps * self.step - current_size, self.k_head_dim)
v_add_shape = (1, self.n_kv_heads, n_steps * self.step - current_size, self.v_head_dim)
k_zeros = mx.zeros(k_add_shape, self.keys.dtype)
v_zeros = mx.zeros(v_add_shape, self.values.dtype)
self.keys = mx.concatenate([self.keys, k_zeros], axis=2)
self.values = mx.concatenate([self.values, v_zeros], axis=2)
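    # Illustration (added comment): if the prior prompt was [1, 2, 3, 4, 5] and the new
    # prompt is [1, 2, 3, 9], Model._evaluate_prompt() finds a common prefix of length 3,
    # calls reuse(4, 3) on each layer cache, and only the suffix [9] is re-evaluated.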
    def update_and_fetch(self, keys, values):
        """
        Override the base class method to allow the cache to be used with batches of
        size greater than 1.
        This is just a tiny change in the line that determines the shape.
        """
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.k_head_dim)
            v_shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

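# Usage sketch (added comment, not exercised by the probe above): passing cache_prompt=True
# to Model.completion() stores the prompt tokens and KV cache, so a subsequent call whose
# prompt shares a prefix with the cached one only re-evaluates the non-shared suffix
# (at least one token), via ReusableKVCache.reuse(). With `model`, `messages` and `schema`
# as in test1() above:
#
#   for _ in model.completion(messages, schema, cache_prompt=True): ...
#   # append a follow-up message to `messages`, then call completion() again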
lsp = llm_struct_probe('mlx-community/Hermes-2-Theta-Llama-3-8B-4bit')
# lsp = llm_struct_probe('mlx-community/Mistral-Nemo-Instruct-2407-4bit')
asyncio.run(lsp.test())