llama-2-13b-chat-transformers
import os
import sys
import time
from typing import List
import logging

import shortuuid
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
# Llama-2 chat template delimiters: each user turn is wrapped in [INST] ... [/INST],
# and the system prompt is wrapped in <<SYS>> ... <</SYS>> inside the first turn.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant."""
MODEL_CONFIG = {
    "model_path": "./llama-2-13b-chat-transformers",  # local HF-format checkpoint
    "dtype": "bfloat16",
    "tensor_parallel_size": 2,       # number of GPUs to shard the model across
    "max_num_batched_tokens": 4096,  # vLLM scheduler limit on tokens per batch
}
class LlamaModel:

    def __init__(self):
        # Build vLLM engine arguments from MODEL_CONFIG and start the async engine.
        args = AsyncEngineArgs(MODEL_CONFIG.get("model_path"))
        args.tokenizer = MODEL_CONFIG.get("model_path")
        args.dtype = MODEL_CONFIG.get("dtype")
        args.tensor_parallel_size = MODEL_CONFIG.get("tensor_parallel_size")
        args.max_num_batched_tokens = MODEL_CONFIG.get("max_num_batched_tokens")
        self.engine = AsyncLLMEngine.from_engine_args(args)
    def format_prompt(self, messages: List[dict]):
        # Prepend the default system prompt if the caller did not supply one.
        if messages[0]["role"] != "system":
            messages = [{
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }] + messages
        # Fold the system prompt into the first user message, as the Llama-2
        # chat format expects.
        messages = [{
            "role": messages[1]["role"],
            "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
        }] + messages[2:]
        # Wrap each (user, assistant) pair as "[INST] user [/INST] answer".
        messages_list = [
            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()}"
            for prompt, answer in zip(messages[::2], messages[1::2])
        ]
        # The final, unanswered user message opens a new [INST] block.
        messages_list.append(
            f"{B_INST} {(messages[-1]['content']).strip()} {E_INST}")
        return "".join(messages_list)
    async def inference(self, messages: List[dict], max_tokens: int,
                        temperature: float):
        start = time.perf_counter()
        request_id = shortuuid.uuid()
        logging.info(f"starting inference {request_id}: {messages}")
        prompt = self.format_prompt(messages)
        logging.info(f"prompt for request {request_id}: {prompt}")
        sampling_params = SamplingParams(temperature=temperature,
                                         max_tokens=max_tokens)
        # Drain the streaming generator; after the loop, `result` holds the
        # final, complete RequestOutput.
        async for result in self.engine.generate(prompt, sampling_params,
                                                 request_id):
            pass
        e2e_inference_time = (time.perf_counter() - start) * 1000
        logging.info(
            f"inference complete {request_id} in {e2e_inference_time:.1f} ms: "
            f"{result.outputs[0].text}")
        return result.outputs[0].text.strip()
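

# Minimal usage sketch (assumes two visible GPUs for tensor_parallel_size=2 and
# a checkpoint at MODEL_CONFIG["model_path"]); the message list follows the
# role/content schema that format_prompt() expects.
if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.INFO)

    async def main():
        model = LlamaModel()
        messages = [{"role": "user", "content": "Summarize what vLLM does in one sentence."}]
        reply = await model.inference(messages, max_tokens=256, temperature=0.7)
        print(reply)

    asyncio.run(main())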