Skip to content

Instantly share code, notes, and snippets.

@janakiramm
Created September 20, 2024 11:16
Show Gist options
  • Save janakiramm/30cdc8dda557379d76e85078f1fb48ef to your computer and use it in GitHub Desktop.
Save janakiramm/30cdc8dda557379d76e85078f1fb48ef to your computer and use it in GitHub Desktop.
NIM-LangChain-RAG-Agent
import os
from datetime import datetime, timedelta
import pytz
import requests
import json
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
# Environment setup
# Supply real credentials through the NVIDIA_API_KEY and AEROAPI_KEY environment
# variables. The placeholder strings below are used only when a variable is
# unset, so running this module never overwrites a key that is already
# configured in the environment.
os.environ.setdefault("NVIDIA_API_KEY", "your_nvidia_api_key")
AEROAPI_BASE_URL = "https://aeroapi.flightaware.com/aeroapi"
AEROAPI_KEY = os.environ.get("AEROAPI_KEY", "your_aero_api_key")
# Initialize LLM
llm = ChatNVIDIA(model="meta/llama-3.1-405b-instruct")
# Flight status tool
@tool
def get_flight_status(flight_id: str):
    """
    Returns flight information for a given flight ID.
    """
    # NOTE: the docstring text above is intentionally left unchanged; `@tool`
    # is expected to use it as the tool description given to the model
    # (confirm against the LangChain docs before rewording it).
    #
    # Behaviour: queries the AeroAPI `/flights/{ident}` endpoint for a window
    # of today..tomorrow (UTC dates), takes the first returned flight and
    # renders a one-sentence status summary with times converted to the
    # origin/destination time zones.
    #
    # Raises: ValueError when the id is empty or no flight is returned;
    # requests exceptions on HTTP/network failure; KeyError/TypeError when the
    # response lacks the origin/destination/status fields read below.
    from urllib.parse import quote  # local import: only this tool needs it

    # `requests` waits forever by default; bound the call so the agent cannot hang.
    request_timeout_seconds = 30

    def get_api_session():
        """Create a requests session that sends the AeroAPI key on every call."""
        session = requests.Session()
        session.headers.update({"x-apikey": AEROAPI_KEY})
        return session

    def normalize_flight_id(raw_flight_id):
        """Strip a leading 'flight_id=' wrapper plus stray whitespace/quotes."""
        cleaned = raw_flight_id
        if "flight_id=" in cleaned:
            cleaned = cleaned.split("flight_id=")[1]
        return cleaned.strip().strip("'\"")

    def fetch_flight_data(flight_id, session):
        """Return the first flight record AeroAPI reports for `flight_id`."""
        # Bare dates are sent, so compute them in UTC rather than from the
        # machine's local clock.
        today_utc = datetime.now(pytz.utc).date()
        query = {
            "start": today_utc.isoformat(),
            "end": (today_utc + timedelta(days=1)).isoformat(),
        }
        # Escape the id so unexpected characters cannot alter the URL path.
        api_resource = f"/flights/{quote(flight_id, safe='')}"
        response = session.get(
            f"{AEROAPI_BASE_URL}{api_resource}",
            params=query,
            timeout=request_timeout_seconds,
        )
        response.raise_for_status()
        flights = response.json().get('flights', [])
        if not flights:
            raise ValueError(f"No flight data found for flight ID {flight_id}.")
        return flights[0]

    def utc_to_local(utc_date_str, local_timezone_str):
        """Convert a 'YYYY-MM-DDTHH:MM:SSZ' UTC string to local wall-clock time."""
        if not utc_date_str:
            # Neither an estimated nor a scheduled time was provided.
            return "unknown"
        utc_datetime = datetime.strptime(
            utc_date_str, '%Y-%m-%dT%H:%M:%SZ'
        ).replace(tzinfo=pytz.utc)
        local_datetime = utc_datetime.astimezone(pytz.timezone(local_timezone_str))
        return local_datetime.strftime('%Y-%m-%d %H:%M:%S')

    # Normalize once up front so the request and the final message agree.
    flight_id = normalize_flight_id(flight_id)
    if not flight_id:
        raise ValueError("A flight ID is required.")

    # The context manager closes the session's connections when done.
    with get_api_session() as session:
        flight_data = fetch_flight_data(flight_id, session)

    # Prefer estimated gate times; fall back to the scheduled ones.
    dep_key = 'estimated_out' if flight_data.get('estimated_out') else 'scheduled_out'
    arr_key = 'estimated_in' if flight_data.get('estimated_in') else 'scheduled_in'

    flight_details = {
        'source': flight_data['origin']['city'],
        'destination': flight_data['destination']['city'],
        'depart_time': utc_to_local(
            flight_data.get(dep_key), flight_data['origin']['timezone']
        ),
        'arrival_time': utc_to_local(
            flight_data.get(arr_key), flight_data['destination']['timezone']
        ),
        'status': flight_data['status'],
    }
    return (
        f"The current status of flight {flight_id} from {flight_details['source']} to {flight_details['destination']} "
        f"is {flight_details['status']} with departure time at {flight_details['depart_time']} and arrival time at "
        f"{flight_details['arrival_time']}."
    )
# LLM with tools
# Bind get_flight_status to the chat model. tool_choice="required" is passed
# straight through to bind_tools (presumably this forces a tool call on every
# invocation; confirm against the ChatNVIDIA documentation).
llm_with_tools = llm.bind_tools([get_flight_status], tool_choice="required")
# Document loading and processing
def load_and_process_documents(url):
    """
    Fetch the page at `url` and return it split into overlapping chunks.

    The page is loaded with WebBaseLoader and split with a
    RecursiveCharacterTextSplitter (chunk_size=500, chunk_overlap=50).
    """
    pages = WebBaseLoader(url).load()
    return RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
    ).split_documents(pages)
# Vector store setup
def setup_vector_store(documents):
    """
    Index `documents` in a FAISS store and return it as a retriever.

    Embeddings are produced by a default-configured NVIDIAEmbeddings instance.
    """
    index = FAISS.from_documents(documents, NVIDIAEmbeddings())
    return index.as_retriever()
# Build the knowledge base: fetch and chunk the Emirates cabin-baggage rules
# page, then index the chunks so `retriever` can look up relevant passages.
documents = load_and_process_documents(
    url="https://www.emirates.com/in/english/before-you-fly/baggage/cabin-baggage-rules/",
)
retriever = setup_vector_store(documents=documents)
# Retrieval function
def retrieve(input_dict):
    """
    Answer the question in `input_dict` from the indexed documents.

    The retrieved chunks are joined into one context string and the LLM is
    asked to answer from it. If the reply contains the "unable to answer"
    marker, the question is handed to use_flight_status_tool instead.

    Returns a dict with the keys "context", "question" and "answer".
    """
    question = input_dict["question"]
    context = " ".join(chunk.page_content for chunk in retriever.invoke(question))

    prompt = "".join([
        f"Based on the following context, can you answer the question '{question}'? ",
        "If yes, provide the answer. If no, respond with 'Unable to answer based on the given context.'\n\n",
        f"Context: {context}",
    ])
    reply = llm.invoke([HumanMessage(content=prompt)])

    # The marker is matched without the trailing period, as a plain substring.
    if "Unable to answer based on the given context" not in reply.content:
        answer = reply.content.strip()
    else:
        answer = use_flight_status_tool(question)

    return {"context": context, "question": question, "answer": answer}
def use_flight_status_tool(question):
    """
    Ask the tool-bound LLM about `question` and run the tool call it proposes.

    Only the first proposed tool call is executed. Returns the tool's output,
    an "Error retrieving flight status: ..." string if resolving or running
    the tool raises, or a fixed fallback string when no tool call was made.
    """
    ai_msg = llm_with_tools.invoke([HumanMessage(content=question)])

    # Guard clause: nothing to execute if the model proposed no tool call.
    if not (hasattr(ai_msg, 'tool_calls') and ai_msg.tool_calls):
        return "Unable to retrieve flight status information."

    first_call = ai_msg.tool_calls[0]
    try:
        requested_name = first_call['name'].lower()
        call_args = first_call['args']
        # Dispatch table mapping tool names to the tools this script exposes.
        available_tools = {"get_flight_status": get_flight_status}
        chosen_tool = available_tools[requested_name]
        return chosen_tool.invoke(call_args['flight_id'])
    except Exception as exc:
        return f"Error retrieving flight status: {str(exc)}"
# RAG pipeline: pass the input dict through unchanged, run `retrieve` on it,
# then keep only the "answer" field of the resulting dict.
rag_chain = RunnablePassthrough() | RunnableLambda(retrieve) | RunnableLambda(
    lambda result: result["answer"]
)
def process_question(question):
    """
    Run `question` through `rag_chain` and return the chain's output.
    """
    chain_input = {"question": question}
    return rag_chain.invoke(chain_input)
# Script entry point
if __name__ == "__main__":
    # Demo: one flight-status question and one baggage-policy question.
    questions = (
        "What is flight status of EK524?",
        "What is the cabin baggage size?",
    )
    for question in questions:
        result = process_question(question)
        print(f"Question: {question}\nAnswer: {result}\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment