@skrawcz
Created August 12, 2024 23:55
Shows how to wrap a Burr application for delegation to Ray. This is one possible strategy to make things run on Ray.

import copy

import openai
from IPython.core.display import HTML
from IPython.display import Image, display

from burr.core import ApplicationBuilder, State, default, graph, when
from burr.core.action import action
from burr.tracking import LocalTrackingClient

MODES = {
    "answer_question": "text",
    "generate_image": "image",
    "generate_code": "code",
    "unknown": "text",
}


@action(reads=[], writes=["chat_history", "prompt"])
def process_prompt(state: State, prompt: str) -> State:
    result = {"chat_item": {"role": "user", "content": prompt, "type": "text"}}
    state = state.append(chat_history=result["chat_item"])
    state = state.update(prompt=prompt)
    return state
@action(reads=["prompt"], writes=["mode"])
def choose_mode(state: State) -> State:
prompt = (
f"You are a chatbot. You've been prompted this: {state['prompt']}. "
f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
"Please respond with *only* a single word representing the mode that most accurately "
"corresponds to the prompt. Fr instance, if the prompt is 'draw a picture of a cat', "
"the mode would be 'generate_image'. If the prompt is "
"'what is the capital of France', the mode would be 'answer_question'."
"If none of these modes apply, please respond with 'unknown'."
)
llm_result = openai.Client().chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": prompt},
],
)
content = llm_result.choices[0].message.content
mode = content.lower()
if mode not in MODES:
mode = "unknown"
result = {"mode": mode}
return state.update(**result)
@action(reads=["prompt", "chat_history"], writes=["response"])
def prompt_for_more(state: State) -> State:
result = {
"response": {
"content": "None of the response modes I support apply to your question. "
"Please clarify?",
"type": "text",
"role": "assistant",
}
}
return state.update(**result)
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def chat_response(
state: State, prepend_prompt: str, model: str = "gpt-3.5-turbo"
) -> State:
chat_history = copy.deepcopy(state["chat_history"])
chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
chat_history_api_format = [
{
"role": chat["role"],
"content": chat["content"],
}
for chat in chat_history
]
client = openai.Client()
result = client.chat.completions.create(
model=model,
messages=chat_history_api_format,
)
text_response = result.choices[0].message.content
result = {"response": {"content": text_response, "type": MODES[state["mode"]], "role": "assistant"}}
return state.update(**result)
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def image_response(state: State, model: str = "dall-e-2") -> State:
"""Generates an image response to the prompt. Optional save function to save the image to a URL."""
# raise ValueError("Demo error")
client = openai.Client()
result = client.images.generate(
model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
)
image_url = result.data[0].url
result = {"response": {"content": image_url, "type": MODES[state["mode"]], "role": "assistant"}}
return state.update(**result)
@action(reads=["response", "mode"], writes=["chat_history"])
def response(state: State) -> State:
# you'd do something specific here based on prior state
result = {"chat_item": state["response"]}
return state.append(chat_history=result["chat_item"])


# Build the graph.
base_graph = (
    graph.GraphBuilder()
    .with_actions(
        # these are the "nodes"
        prompt=process_prompt,
        decide_mode=choose_mode,
        generate_image=image_response,
        generate_code=chat_response.bind(
            prepend_prompt="Please respond with *only* code and no other text (at all) to the following:",
        ),
        answer_question=chat_response.bind(
            prepend_prompt="Please answer the following question:",
        ),
        prompt_for_more=prompt_for_more,
        response=response,
    )
    .with_transitions(
        # these are the edges between nodes, based on state.
        ("prompt", "decide_mode", default),
        ("decide_mode", "generate_image", when(mode="generate_image")),
        ("decide_mode", "generate_code", when(mode="generate_code")),
        ("decide_mode", "answer_question", when(mode="answer_question")),
        ("decide_mode", "prompt_for_more", default),
        (
            ["generate_image", "answer_question", "generate_code", "prompt_for_more"],
            "response",
        ),
        ("response", "prompt", default),
    )
    .build()
)

# base_graph.visualize()

import ray


@ray.remote
def run_agent(user_input: str, app_id: str) -> State:
    """
    A simple wrapper around creating an agent application and calling it.
    Uses a persister for persistence between calls (postgres, etc. support partition keys too).
    The graph was built above in a different cell; Ray is able to serialize the code
    that the application below references without issue.
    """
    tracker = LocalTrackingClient(project="agent-demo-ray")  # I'm using it as a persister here.
    app = (
        ApplicationBuilder()
        .with_graph(base_graph)
        .initialize_from(
            tracker,
            resume_at_next_action=True,
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        .with_identifiers(app_id=app_id)
        .with_tracker(tracker)  # tracking + checkpointing/persisting; one line 🪄.
        .build()
    )
    last_action, action_result, app_state = app.run(
        halt_after=["response"],
        inputs={"prompt": user_input},
    )
    return app_state


if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    object_ref = run_agent.remote("what is the capital of france?", "test1234")
    print(ray.get(object_ref))
    # uses prior history because app_id is the same...
    object_ref = run_agent.remote("write hello world in java", "test1234")
    print(ray.get(object_ref))
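
For reference, the run_agent docstring above mentions database-backed persisters with partition keys. Below is a minimal sketch of that swap using Burr's SQLite persister (the constructor arguments, the run_agent_with_db name, and the user_id partition key are illustrative assumptions; check the Burr persistence docs for your version -- the Postgres persister follows the same pattern):

from burr.core import persistence


@ray.remote
def run_agent_with_db(user_input: str, app_id: str, user_id: str) -> State:
    # database-backed persister instead of the LocalTrackingClient
    persister = persistence.SQLLitePersister(db_path="./burr_state.db", table_name="burr_state")
    persister.initialize()  # create the backing table if it doesn't exist
    app = (
        ApplicationBuilder()
        .with_graph(base_graph)
        .initialize_from(
            persister,
            resume_at_next_action=True,
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        # partition_key lets one table hold state for many users
        .with_identifiers(app_id=app_id, partition_key=user_id)
        .with_state_persister(persister)  # write state back as the app runs
        .build()
    )
    _, _, app_state = app.run(halt_after=["response"], inputs={"prompt": user_input})
    return app_state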
@skrawcz (Author) commented Aug 12, 2024
I tested this in a notebook and copy-pasted it here -- so it should work...

@skrawcz (Author) commented Aug 13, 2024
You could also directly pass in the state to the function, rather than using the persister; a minimal sketch follows.
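
To illustrate that alternative: thread the returned State back in from the caller and seed the application with with_state/with_entrypoint instead of initialize_from. Note run_agent_with_state is a hypothetical name, and State.get_all() returning a plain dict is an assumption -- check your Burr version:

@ray.remote
def run_agent_with_state(user_input: str, prior_state: State) -> State:
    # state comes from the caller, so no persister is needed between calls
    app = (
        ApplicationBuilder()
        .with_graph(base_graph)
        .with_state(**prior_state.get_all())  # seed directly from the prior state
        .with_entrypoint("prompt")
        .build()
    )
    _, _, app_state = app.run(halt_after=["response"], inputs={"prompt": user_input})
    return app_state


# thread state through successive calls:
state_1 = ray.get(run_agent_with_state.remote("what is the capital of france?", State({"chat_history": []})))
state_2 = ray.get(run_agent_with_state.remote("write hello world in java", state_1))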
