Last active
November 14, 2023 10:01
-
-
Save gbaeke/e6e88c0dc68af3aa4a89b1228012ae53 to your computer and use it in GitHub Desktop.
OpenAI Assistant Demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pickle
import time
from collections import OrderedDict
from io import BytesIO

import dotenv
import openai
import streamlit as st
from PIL import Image
# Load environment variables from .env (expects OPENAI_API_KEY=your_key).
dotenv.load_dotenv()

# The OpenAI client picks up OPENAI_API_KEY from the environment.
client = openai.Client()

# On-disk LRU cache of file contents fetched from the OpenAI Files API.
CACHE_FILE = 'cache.pkl'
MAX_CACHE_SIZE = 20

# Load the cache from disk, falling back to an empty cache when the file
# is missing or unreadable (e.g. a corrupted/truncated pickle) so a bad
# cache file cannot crash the app at startup.
# NOTE(review): pickle.load is unsafe on untrusted data; cache.pkl is
# assumed to be written only by this app (see get_content).
cache = OrderedDict()
if os.path.exists(CACHE_FILE):
    try:
        with open(CACHE_FILE, 'rb') as f:
            cache = pickle.load(f)
    except (pickle.UnpicklingError, EOFError, OSError):
        cache = OrderedDict()
def get_content(file_id):
    """Return the raw bytes of an OpenAI file, via an on-disk LRU cache.

    Cache hits are promoted to most-recently-used. Misses are fetched
    from the Files API, inserted, the cache is bounded to
    MAX_CACHE_SIZE entries (least-recently-used evicted first), and the
    whole cache is persisted to CACHE_FILE.
    """
    # Cache hit: promote to most-recently-used and return.
    if file_id in cache:
        cache.move_to_end(file_id)
        return cache[file_id]

    # Cache miss: fetch the content from the OpenAI Files API.
    content = client.files.content(file_id).response.content
    cache[file_id] = content

    # Evict least-recently-used entries until the cache fits the bound.
    # A while-loop keeps the invariant even if MAX_CACHE_SIZE shrinks
    # between runs.
    while len(cache) > MAX_CACHE_SIZE:
        cache.popitem(last=False)

    # Persist the updated cache so it survives app restarts/reruns.
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(cache, f)

    return content
def main():
    """Streamlit entry point: a Math Tutor chat backed by an OpenAI Assistant."""
    # Reuse (or lazily create) the assistant. Its id is persisted to a
    # local file so reruns and new sessions reuse the same assistant.
    # NOTE(review): this errors if the assistant was deleted on
    # platform.openai.com while assistant_id.txt still exists.
    if 'assistant' not in st.session_state:
        if os.path.exists("assistant_id.txt"):
            with open("assistant_id.txt", "r") as f:
                # strip() guards against a trailing newline corrupting the id
                st.session_state.assistant_id = f.read().strip()
        else:
            # Create a new assistant with the code interpreter tool.
            assistant = client.beta.assistants.create(
                name="Math Tutor",
                instructions="You are a personal math tutor. Write and run code to answer math questions.",
                tools=[{"type": "code_interpreter"}],
                model="gpt-4-1106-preview"
            )
            st.session_state.assistant_id = assistant.id
            # Persist the id for later runs.
            with open("assistant_id.txt", "w") as f:
                f.write(st.session_state.assistant_id)

    # One conversation thread per browser session.
    if 'thread' not in st.session_state:
        st.session_state.thread = client.beta.threads.create()

    st.title("Math Tutor Assistant")

    user_input = st.text_input("Enter a math question")

    if user_input:
        # Add the user's question to the thread.
        client.beta.threads.messages.create(
            thread_id=st.session_state.thread.id,
            role="user",
            content=user_input
        )

        # Run the assistant on the thread with per-run instructions.
        run = client.beta.threads.runs.create(
            thread_id=st.session_state.thread.id,
            assistant_id=st.session_state.assistant_id,
            instructions="Please address the user as Geert. Only answer math questions."
        )

        # Poll until the run reaches a terminal state. Sleep between
        # polls so we don't busy-wait and hammer the API; 'expired' is
        # also terminal, otherwise the loop could spin forever.
        with st.spinner('Waiting for completion...'):
            while run.status not in ('completed', 'failed', 'cancelled', 'expired'):
                time.sleep(1)
                run = client.beta.threads.runs.retrieve(
                    thread_id=st.session_state.thread.id,
                    run_id=run.id
                )
            if run.status != 'completed':
                st.error(f"Run ended with status: {run.status}")
                st.stop()

        # Retrieve the full message history for the thread.
        messages = client.beta.threads.messages.list(
            thread_id=st.session_state.thread.id
        )
        # The API returns newest-first; reverse so the conversation reads
        # top-to-bottom with question and answer in the right order.
        messages.data.reverse()
        st.write("Most recent messages are at the bottom")

        try:
            # Only text and image_file content is rendered; file
            # downloads are not supported yet.
            for message in messages.data:
                if message.role == 'user':
                    st.markdown(f"**User:** {message.content[0].text.value}")
                if message.role == 'assistant':
                    for part in message.content:
                        if hasattr(part, 'text'):
                            st.markdown(f"**Assistant:** {part.text.value}")
                        elif hasattr(part, 'image_file'):
                            # Fetch the image bytes via the LRU cache.
                            image_bytes = get_content(part.image_file.file_id)
                            image = Image.open(BytesIO(image_bytes))
                            st.image(image, caption="Downloaded Image", use_column_width=True)
        except Exception as e:
            st.error(e)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment