@j40903272
Created September 5, 2023 05:31
llama-2-13b-chat-transformers
import time
from typing import List
import logging

import shortuuid
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

# Llama-2 chat template delimiters for instruction turns and the system prompt.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant."""

MODEL_CONFIG = {
    "model_path": "./llama-2-13b-chat-transformers",
    "dtype": "bfloat16",
    "tensor_parallel_size": 2,  # split the model across two GPUs
    "max_num_batched_tokens": 4096,
}

class LlamaModel:

    def __init__(self):
        # Build the async vLLM engine from the static MODEL_CONFIG above.
        args = AsyncEngineArgs(MODEL_CONFIG.get("model_path"))
        args.tokenizer = MODEL_CONFIG.get("model_path")
        args.dtype = MODEL_CONFIG.get("dtype")
        args.tensor_parallel_size = MODEL_CONFIG.get("tensor_parallel_size")
        args.max_num_batched_tokens = MODEL_CONFIG.get("max_num_batched_tokens")
        self.engine = AsyncLLMEngine.from_engine_args(args)
    def format_prompt(self, messages: List[dict]) -> str:
        """Flatten an OpenAI-style message list into the Llama-2 chat prompt format."""
        # Prepend the default system prompt if the caller did not supply one.
        if messages[0]["role"] != "system":
            messages = [{
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }] + messages
        # Llama-2 expects the system prompt folded into the first user message.
        messages = [{
            "role": messages[1]["role"],
            "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
        }] + messages[2:]
        # Render each completed (user, assistant) turn, then the trailing user message.
        messages_list = [
            f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()}"
            for prompt, answer in zip(messages[::2], messages[1::2])
        ]
        messages_list.append(
            f"{B_INST} {messages[-1]['content'].strip()} {E_INST}")
        return "".join(messages_list)
    async def inference(self, messages: List[dict], max_tokens: int,
                        temperature: float) -> str:
        start = time.perf_counter()
        request_id = shortuuid.uuid()
        logging.info(f"starting inference {request_id}: {messages}")
        prompt = self.format_prompt(messages)
        logging.info(f"prompt for request {request_id}: {prompt}")
        sampling_params = SamplingParams(temperature=temperature,
                                         max_tokens=max_tokens)
        # Drain the streaming generator; `result` holds the final output
        # after the last async iteration.
        async for result in self.engine.generate(prompt, sampling_params,
                                                 request_id):
            pass
        e2e_inference_time = (time.perf_counter() - start) * 1000
        logging.info(
            f"inference complete {request_id} in {e2e_inference_time} ms: "
            f"{result.outputs[0].text}")
        return result.outputs[0].text.strip()
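
A minimal usage sketch (not part of the original gist): it assumes the weights actually exist at ./llama-2-13b-chat-transformers and that two GPUs are available to satisfy tensor_parallel_size=2. The main coroutine and the sample message are illustrative only.

import asyncio

async def main():
    # Hypothetical driver: construct the engine and run one chat turn.
    model = LlamaModel()
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    reply = await model.inference(messages, max_tokens=256, temperature=0.7)
    print(reply)

if __name__ == "__main__":
    asyncio.run(main())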