Llama 2 chat completion wrapper for use with a Hugging Face Inference Endpoint
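The client reads the endpoint URL and access token from the HF_ENDPOINT_URL and HF_TOKEN environment variables. A minimal setup sketch, with placeholder values (substitute your own Inference Endpoint URL and token):

# Placeholder values; set these before constructing Llama2ChatCompletionWrapper.
import os

os.environ["HF_ENDPOINT_URL"] = "https://<your-endpoint>.endpoints.huggingface.cloud"
os.environ["HF_TOKEN"] = "hf_..."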
import os
from typing import Callable, List, Literal, TypedDict

from huggingface_hub import InferenceClient
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
BOS, EOS = "<s>", "</s>"

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def _llama2_format_messages(messages: List[Message]) -> str:
    # If no system message is given, prepend the default system prompt.
    if messages[0]["role"] != "system":
        messages = [
            {
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }
        ] + messages
    # Fold the system prompt into the first user message, wrapped in <<SYS>> tags.
    messages = [
        {
            "role": messages[1]["role"],
            "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
        }
    ] + messages[2:]
    assert all([msg["role"] == "user" for msg in messages[::2]]) and all(
        [msg["role"] == "assistant" for msg in messages[1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    # Each completed (user, assistant) pair becomes one "<s>[INST] user [/INST] assistant </s>" turn.
    formatted_messages: str = "".join(
        [
            f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
            for prompt, answer in zip(
                messages[::2],
                messages[1::2],
            )
        ]
    )
    assert messages[-1]["role"] == "user", f"Last message must be from user, got {messages[-1]['role']}"
    formatted_messages += f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}"
    return formatted_messages
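# For example, a history [system, user_1, assistant_1, user_2] is flattened into a single
# prompt string of the form:
#   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user_1} [/INST] {assistant_1} </s><s>[INST] {user_2} [/INST]
# Completed turns are closed with </s>; the last user turn is left open for the model to answer.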
class Llama2ChatCompletionWrapper:
    def __init__(self, callback: Callable[[Message], None] | None = None) -> None:
        self.callback = callback
        # Streaming client
        self.client = InferenceClient(os.environ["HF_ENDPOINT_URL"], token=os.environ["HF_TOKEN"])
        # Default generation parameters
        self.default_gen_kwargs = dict(
            max_new_tokens=512,
            top_k=30,
            top_p=0.9,
            temperature=0.2,
            repetition_penalty=1.02,
            stop_sequences=["</s>"],
        )
    def new_session(self, system_content: str | None = None, messages: List[Message] | None = None):
        self.messages: List[Message] = []
        # if self.callback is not None:
        #     self.callback()
        if system_content is not None:
            assert messages is None
            self.messages.append(Message(role="system", content=system_content))
            if self.callback is not None:
                self.callback(self.messages[-1])
        elif messages is not None:
            self.messages = messages
            if self.callback is not None:
                for msg in self.messages:
                    self.callback(msg)
    def __call__(self, message: str, post_process: Callable[[str], str] | None = None, **gen_kwargs) -> str:
        self.messages.append(Message(role="user", content=message))
        if self.callback is not None:
            self.callback(self.messages[-1])
        formatted_messages = _llama2_format_messages(self.messages)
        params = dict(self.default_gen_kwargs, **gen_kwargs)  # overwriting default parameters
        generated_text = self.client.text_generation(formatted_messages, stream=False, details=False, **params)
        result = generated_text.strip()
        if post_process is not None:
            # if self.callback is not None:
            #     self.callback()
            result = post_process(result)
        self.messages.append(Message(role="assistant", content=result))
        if self.callback is not None:
            self.callback(self.messages[-1])
        return result
def console_print(message: Message) -> None:
    reset = "\033[00m"
    color_map = {
        "system": ("\033[1;35m", "\033[35m"),
        "user": ("\033[1;33m", "\033[33m"),
        "assistant": ("\033[1;31m", "\033[31m"),
    }
    role_color, content_color = color_map[message["role"]]
    formatted_message = f"{role_color}{message['role'].upper()}{reset}> {content_color}{message['content']}{reset}"
    print(formatted_message)
if __name__ == "__main__":
    params = dict(temperature=0.1, top_p=0.9, top_k=None, repetition_penalty=None)
    llm = Llama2ChatCompletionWrapper(callback=console_print)
    llm.new_session(system_content="You are a pirate! Think and speak like one!")
    answer = llm("How old is the Earth?", **params)
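The wrapper keeps the running conversation in self.messages, so follow-up calls within the same session are sent with the earlier turns as context, and new_session discards the history. A short multi-turn sketch, assuming HF_ENDPOINT_URL and HF_TOKEN are set and the file is importable as llama2_chat (a hypothetical module name):

from llama2_chat import Llama2ChatCompletionWrapper, console_print  # hypothetical module name

llm = Llama2ChatCompletionWrapper(callback=console_print)
llm.new_session(system_content="You are a concise geography tutor.")

llm("What is the capital of France?")
llm("And roughly how many people live there?")  # earlier turns are re-sent as context

# Starting a new session clears the stored history.
llm.new_session(system_content="You are a pirate! Think and speak like one!")
llm("How old is the Earth?", temperature=0.1)  # per-call kwargs override the default generation parameters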