Last active
January 20, 2023 20:16
-
-
Save Sam-Belliveau/d7b466085f91bc4e752d31dfab19b91a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union | |
import asyncio | |
from emoji import is_emoji | |
from time import time | |
from discord import Client, Intents | |
from openai import Completion | |
# Completion model names, keyed by a rough cost/quality label.
OPENAI_MODELS = dict(
    cheapest='text-ada-001',
    cheap='text-babbage-001',
    good='text-davinci-001',
    expensive='text-davinci-003',
)

# Keyword arguments that are passed along with every completion request.
OPENAI_CONFIG = dict(
    model=OPENAI_MODELS['good'],
    temperature=1,
    top_p=1,
    presence_penalty=2,
    frequency_penalty=2,
    n=1,
)

# Timing knobs (all in seconds) used to make the bot feel less instantaneous.
WAIT_BEGIN_READ_SECS = 0.2  # Amount of time before the bot will begin reading
WAIT_PER_WORD_READ_SECS = 60 / 400  # How fast the robot can read a word
WAIT_PER_WORD_WRITE_SECS = 60 / 200  # How fast the robot can type a word
REACTION_TIME = 0.1  # Amount of time in between each emoji reaction
def _sleep_read_message(message: str):
    """Return an un-awaited ``asyncio.sleep`` sized to "read" *message*.

    The delay is a fixed start-up time plus a per-word cost, where words are
    the whitespace-separated pieces of ``str(message)``.
    """
    word_count = len(str(message).split())
    delay = WAIT_BEGIN_READ_SECS + WAIT_PER_WORD_READ_SECS * word_count
    return asyncio.sleep(delay)
def _sleep_write_message(message: str, overhead: float = 0):
    """Return an un-awaited ``asyncio.sleep`` sized to "type" *message*.

    Args:
        message: Text about to be sent. ``None`` is treated as "nothing to
            type" rather than being stringified to the one-word text "None".
        overhead: Seconds already spent producing the message; subtracted
            from the typing delay so the total wait is not inflated.

    Returns:
        The coroutine produced by ``asyncio.sleep`` (never a negative delay).
    """
    word_count = 0 if message is None else len(str(message).split())
    delay = WAIT_PER_WORD_WRITE_SECS * word_count - overhead
    return asyncio.sleep(max(0, delay))
def _sleep_react_emoji():
    """Return an un-awaited ``asyncio.sleep`` lasting ``REACTION_TIME`` seconds."""
    pause = asyncio.sleep(REACTION_TIME)
    return pause
def _refactor_message(msg: str) -> str: | |
return str(msg).replace("\n", " ").strip() | |
def _refactor_prompt(prompt: str) -> str: | |
return "\n".join(line.strip() for line in prompt.strip().splitlines()) | |
def _openai_get_response(prompt: str, length: int) -> str:
    """Request one completion for *prompt* and return its trimmed text.

    The prompt is normalised with ``_refactor_prompt`` first; *length* is sent
    as ``max_tokens`` together with the shared ``OPENAI_CONFIG`` settings.
    Returns an empty string when the response carries no choices.
    """
    completion = Completion.create(
        prompt=_refactor_prompt(prompt),
        max_tokens=length,
        **OPENAI_CONFIG,
    )
    candidates = completion.choices
    if not candidates:
        return ""
    return str(candidates[0].text).strip()
# Socially Awkward Messenger Bot
class SAMBot(Client):
    """Discord client that answers channel messages with OpenAI completions.

    A message that mentions the bot is treated as a configuration command
    (change name, change prompt, show configuration). Any other message makes
    the bot collect recent channel history, ask the model for emoji reactions
    and for a text reply, and post them with simulated reading/typing delays.

    Messages are handled strictly one at a time in arrival order: each message
    takes a numeric handle and waits until every earlier handle has been marked
    as returned. A message that is no longer the newest one is abandoned.
    """

    # Keywords recognised inside a command; the one appearing first wins.
    _COMMAND_KEYWORDS = ("name", "prompt", "get")

    def __init__(self, name: str, prompt: str, response_length, memory_length):
        """Create the bot.

        Args:
            name: Name the bot uses for itself (also applied as guild nickname).
            prompt: Instruction text placed at the top of every model prompt.
            response_length: Token limit passed for text replies.
            memory_length: Number of channel messages fetched as history.
        """
        # NOTE(review): 68672 is passed as the raw intents value; it looks like
        # it could be a permissions integer - TODO confirm against the bot's
        # intended gateway intents.
        super().__init__(intents=Intents(68672, messages=True, message_content=True))

        self.prompt: str = _refactor_prompt(prompt)
        self.name: str = name
        self.response_length: int = response_length
        self.memory_length: int = memory_length

        # Handle bookkeeping: `_msg_id` is the newest handle given out,
        # `_msg_id_handled` is the lowest handle that is allowed to run.
        self._msg_id: int = 0
        self._msg_id_handled: int = 0
        self._get_msg_returned(self._msg_id)

    def _get_msg_handle(self):
        """Allocate and return the next message handle."""
        self._msg_id += 1
        return self._msg_id

    def _get_msg_invalid(self, id: int) -> bool:
        """Return True when a newer handle than *id* has been allocated."""
        return id < self._msg_id

    def _get_msg_returned(self, id: int) -> None:
        """Mark handle *id* as finished so the next handle may proceed."""
        self._msg_id_handled = max(id + 1, self._msg_id_handled)

    def get_message_emojis(self, messages: str) -> str:
        """Ask the model which emojis to react with and return only the emojis.

        The response is filtered character by character with ``is_emoji``, so
        an emoji made of several code points is not kept as one unit.
        """
        prompt = f'''
        {self.prompt}
        {messages}
        Which Emojis would {self.name} react to these messages with? (EMOJIS ONLY)
        '''
        response = _openai_get_response(prompt=prompt, length=16)
        return "".join(char for char in response if is_emoji(char))

    def get_messages_response(self, messages: str) -> Union[None, str]:
        """Ask the model for the bot's reply to *messages*.

        Returns:
            The reply text, or ``None`` when nothing usable was produced.
        """
        prompt = f'''
        {self.prompt}
        {messages}
        [{self.name}]
        '''
        response = _openai_get_response(prompt=prompt, length=self.response_length)

        # Sometimes the AI likes to respond to itself, this prevents that:
        # drop empty lines and any line that looks like a "[speaker] ..." tag.
        kept_lines = [
            line.strip()
            for line in response.splitlines()
            if line and not ('[' in line and ']' in line)
        ]
        response = '\n'.join(kept_lines).strip()
        return response or None

    def log(self, id: int, level: int, message: str) -> None:
        """Print *message* tagged with handle *id*, indented by *level* tabs."""
        tabbing = '\t' * level
        print(f"[{id}]{tabbing}- {message}", flush=True)

    async def on_ready(self) -> None:
        """Log that the client is running."""
        self.log(0, 0, f'{self.name} is now running!')

    async def print_config(self, channel) -> None:
        """Send the current configuration to *channel*."""
        await channel.send(_refactor_prompt(f'''
        >>> __**<@{self.user.id}>'s Current Configuration:**__
        **1) Name:** `{self.name}`
        **2) Prompt:**
        ```
        {self.prompt}
        ```
        **3) Memory Length:** `{self.memory_length} messages`
        **4) Response Length:** `{self.response_length} tokens`
        **5) OpenAI Model:** `{OPENAI_CONFIG['model']}`
        '''))

    def _parse_command(self, command: str):
        """Split *command* into ``(keyword, argument)``.

        The keyword is whichever of ``_COMMAND_KEYWORDS`` occurs earliest in
        the lower-cased command, so text after a keyword (for example a prompt
        containing the word "name") cannot be mistaken for another command.
        Returns ``(None, "")`` when no keyword is present.
        """
        lowered = command.lower()
        hits = [(lowered.find(keyword), keyword) for keyword in self._COMMAND_KEYWORDS]
        hits = [(position, keyword) for position, keyword in hits if position >= 0]
        if not hits:
            return None, ""
        position, keyword = min(hits)
        return keyword, command[position + len(keyword):]

    async def _get_display_name(self, guild, user) -> str:
        """Return the name used for *user* in the message history.

        Uses the guild nickname when one is set and falls back to the user
        name otherwise (including when there is no guild, e.g. a DM).
        """
        if guild is None:
            return user.name
        member = await guild.fetch_member(user.id)
        return member.nick or member.name

    async def _handle_command(self, msg_handle: int, message) -> None:
        """Execute the configuration command contained in *message*."""
        content = str(message.content)
        mention = f"<@{self.user.id}>"

        # The command is whatever follows the mention. Accept both mention
        # spellings, and fall back to the whole text when the mention is not
        # literally present in the content.
        command = content
        for mention_form in (mention, f"<@!{self.user.id}>"):
            position = content.find(mention_form)
            if position >= 0:
                command = content[position + len(mention_form):]
                break
        command += " "

        self.log(msg_handle, 0, f"{self.name} Received Command!")
        self.log(msg_handle, 1, f'"{command}"')

        keyword, argument = self._parse_command(command)

        if keyword is None:
            await message.channel.send(_refactor_prompt(f'''
            >>> **{mention}'s Commands List:** `set name`, `set prompt`, `get config`
            **Unknown Command:** `{command}`
            '''))
            return

        if keyword == "name":
            new_name = _refactor_message(argument)
            # An empty name would blank the bot's identity; keep the old one.
            if new_name:
                self.name = new_name
        elif keyword == "prompt":
            self.prompt = _refactor_prompt(argument)

        # "get" and both setters all finish by showing the configuration.
        await self.print_config(message.channel)

    async def _respond_to_message(self, msg_handle: int, message) -> None:
        """React to and answer a regular (non-command) message."""
        self.log(msg_handle, 0, f"{self.name} Received Message!")

        # if there is a newer message, don't send a request to OpenAI
        if self._get_msg_invalid(msg_handle):
            return

        loop = asyncio.get_running_loop()

        # Start the reading delay as a task so it runs while the history is
        # fetched and the model is queried (a bare coroutine would not start
        # until it is awaited).
        reading_start = time()
        reading = asyncio.create_task(_sleep_read_message(message.content))

        history_lines = []
        async for past in message.channel.history(limit=self.memory_length):
            if self.user in past.mentions:
                continue  # commands are not part of the conversation
            author = await self._get_display_name(message.guild, past.author)
            history_lines.append(f"[{author}] {_refactor_message(past.content)}")
        message_history = '\n'.join(reversed(history_lines))

        self.log(msg_handle, 1, "Sending Open AI Message History:")
        for line in message_history.splitlines():
            self.log(msg_handle, 2, line)

        # The OpenAI request is blocking; run it in the default executor so
        # the event loop stays responsive.
        emoji_reactions = await loop.run_in_executor(
            None, self.get_message_emojis, message_history
        )

        await reading
        self.log(msg_handle, 1, f"Simulated Reading Speed [{time() - reading_start:.2f}s]")

        if emoji_reactions:
            self.log(msg_handle, 1, f"Received Emoji Reactions from OpenAI: [{time() - reading_start:.2f}s]")
            for emoji in emoji_reactions:
                self.log(msg_handle, 2, emoji)
                pause = asyncio.create_task(_sleep_react_emoji())
                try:
                    await message.add_reaction(emoji)
                except Exception as error:
                    # Best effort: a rejected reaction must not stop the reply.
                    self.log(msg_handle, 3, f"Reaction Failed: {error}")
                await pause

        # if there is a newer message, don't send a request to OpenAI
        if self._get_msg_invalid(msg_handle):
            return

        message_start = time()
        async with message.channel.typing():
            response = await loop.run_in_executor(
                None, self.get_messages_response, message_history
            )
            self.log(msg_handle, 1, f"Received Response from OpenAI [{time() - message_start:.2f}s]")

            # `response` may be None; an empty text means no typing delay.
            await _sleep_write_message(response or "", time() - message_start)
            self.log(msg_handle, 1, f"Simulated Typing Speed [{time() - message_start:.2f}s]")

        # after simulated typing, if there is a new message, dont send it
        # messages replying to old messages will look out of place
        if self._get_msg_invalid(msg_handle):
            return

        self.log(msg_handle, 1, "Sending Discord Response:")
        if response:
            self.log(msg_handle, 2, f"[{self.name}] {response}")
            await message.channel.send(response)
        else:
            self.log(msg_handle, 2, "No Response Given")

    async def on_message(self, message) -> None:
        """Handle an incoming message: sync nickname, then command or reply."""
        guild = message.guild

        # update the current nickname if it has changed (best effort, and
        # skipped when the message has no guild)
        if guild is not None:
            try:
                me = await guild.fetch_member(self.user.id)
                if me.nick != self.name:
                    await me.edit(nick=self.name)
            except Exception as error:
                self.log(0, 1, f"Could Not Update Nickname: {error}")

        # do not respond to yourself
        if message.author == self.user:
            return

        msg_handle = self._get_msg_handle()

        # use try / finally block in order to make sure every single message completes
        try:
            # wait until the previous message has been completed
            while self._msg_id_handled < msg_handle:
                await asyncio.sleep(REACTION_TIME)

            if self.user in message.mentions:
                await self._handle_command(msg_handle, message)
            else:
                await self._respond_to_message(msg_handle, message)

        # handle every exception
        except Exception as e:
            self.log(msg_handle, 1, "Exception Caught:")
            # Fall back to repr() so exceptions without a message still show up.
            for line in (str(e) or repr(e)).splitlines():
                self.log(msg_handle, 2, line)

        # make sure that every message is marked as return
        finally:
            if self._get_msg_invalid(msg_handle):
                self.log(msg_handle, 0, f"Message Interrupted By {self._msg_id - msg_handle} Newer Message(s)")
            else:
                self.log(msg_handle, 0, "Message Response Complete!")
            # No `return` here: returning from `finally` would swallow
            # cancellation and other non-Exception errors.
            self._get_msg_returned(msg_handle)
if __name__ == '__main__':
    from os import getenv

    # Fail early with a clear message instead of handing `None` to `run()`.
    token = getenv("DISCORD_TOKEN")
    if not token:
        raise SystemExit("The DISCORD_TOKEN environment variable is not set.")

    bot = SAMBot(
        "SamBot",
        "Respond to the following messages as SAMBot",
        response_length=50,
        memory_length=8,
    )
    bot.run(token)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment