Skip to content

Instantly share code, notes, and snippets.

@ColeMurray
Created September 8, 2024 01:22
Show Gist options
  • Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.
Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.
Example using OpenHands to create a vector database server
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
import faiss
import numpy as np
import os
# ASGI application object; the @app.post route decorators further down register
# their handlers on it, and the __main__ guard hands it to uvicorn.
app = FastAPI()
class QueryRequest(BaseModel):
    """JSON body accepted by the ``/query`` endpoint.

    ``namespace`` and ``identifier`` together select which stored index is
    searched; ``query`` is the vector to search with and ``num_results`` is
    the number of neighbours requested.
    """

    # Vector to search with, given as a plain list of floats.
    query: List[float]

    # The two halves of the database key.
    namespace: str
    identifier: str

    # Number of neighbours to return; 10 when the field is omitted.
    num_results: Optional[int] = 10
import pickle
class VectorDatabase:
    """File-backed collection of FAISS flat-L2 indexes.

    Each database is addressed by a ``(namespace, identifier)`` pair and kept
    in memory under the key ``"{namespace}_{identifier}"``.  On disk it is two
    files inside ``base_path``: ``<key>.index`` (the FAISS index) and
    ``<key>.metadata`` (a pickle holding the raw vectors and their caller
    supplied ids, in insertion order).

    NOTE: because the key is a plain underscore join, ``("a_b", "c")`` and
    ``("a", "b_c")`` address the same database.  The naming scheme is kept
    unchanged so files written by earlier versions still load.
    """

    def __init__(self, base_path: str = "./vector_dbs"):
        """Create the storage directory if needed and load every saved index.

        Args:
            base_path: Directory holding the ``.index`` / ``.metadata`` files.
        """
        self.base_path = base_path
        # key -> FAISS index
        self.databases: Dict[str, faiss.IndexFlatL2] = {}
        # key -> list of (1, dim) float32 arrays, in insertion order
        self.vectors: Dict[str, List[np.ndarray]] = {}
        # key -> caller supplied ids; position i belongs to FAISS row i
        self.ids: Dict[str, List[int]] = {}
        os.makedirs(base_path, exist_ok=True)
        self.load_all_dbs()

    @staticmethod
    def _reject_unsafe_name(value: str, field: str) -> None:
        """Raise HTTP 400 if *value* could escape ``base_path`` when used in a file name."""
        if any(ch in value for ch in ("/", "\\", "\0")):
            raise HTTPException(
                status_code=400,
                detail=f"{field} must not contain path separators",
            )

    def db_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the FAISS index file for the given database."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.index")

    def metadata_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the pickled metadata file for the given database."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.metadata")

    def load_all_dbs(self):
        """Load every ``*.index`` file found in ``base_path``.

        The file stem is split at its first underscore into namespace and
        identifier.  Files whose stem has no underscore cannot have been
        written by ``save_db`` and are skipped instead of aborting start-up.
        """
        suffix = ".index"
        for filename in os.listdir(self.base_path):
            if not filename.endswith(suffix):
                continue
            namespace, separator, identifier = filename[: -len(suffix)].partition("_")
            if not separator:
                print(f"Skipping index file without namespace separator: {filename}")
                continue
            self.load_db(namespace, identifier)

    def load_db(self, namespace: str, identifier: str):
        """Load one database (index plus metadata) from disk into memory.

        A missing index file is reported and otherwise ignored.  A missing
        metadata file leaves the database loaded with empty vector/id lists.
        """
        db_key = f"{namespace}_{identifier}"
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)
        print(f"Checking index path: {index_path}")
        print(f"Checking metadata path: {metadata_path}")

        if not os.path.exists(index_path):
            print(f"Index file does not exist: {index_path}")
            return

        self.databases[db_key] = faiss.read_index(index_path)
        if os.path.exists(metadata_path):
            # SECURITY: pickle executes arbitrary code while loading.  This is
            # only acceptable while base_path is writable solely by this
            # service; never point it at files from an untrusted source.
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
            self.vectors[db_key] = metadata['vectors']
            self.ids[db_key] = metadata['ids']
            print(f"Metadata file contents: {metadata}")
        else:
            print("Metadata file does not exist")
            self.vectors[db_key] = []
            self.ids[db_key] = []
        print(f"Loaded database: {db_key}")
        print(f"Number of vectors: {self.databases[db_key].ntotal}")
        print(f"Number of IDs: {len(self.ids[db_key])}")
        print(f"IDs: {self.ids[db_key]}")

    def save_db(self, namespace: str, identifier: str):
        """Write the index and its metadata to disk; no-op for an unknown database."""
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            return
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)
        faiss.write_index(self.databases[db_key], index_path)
        with open(metadata_path, 'wb') as f:
            pickle.dump({'vectors': self.vectors[db_key], 'ids': self.ids[db_key]}, f)
        print(f"Saved database: {db_key}")

    def create_or_get_db(self, namespace: str, identifier: str, dim: int):
        """Return the index for the database, creating an empty one of size *dim* if absent.

        *dim* is only used on creation; an existing index keeps its dimension.
        """
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            self.databases[db_key] = faiss.IndexFlatL2(dim)
            self.vectors[db_key] = []
            self.ids[db_key] = []
        return self.databases[db_key]

    def add_vector(self, namespace: str, identifier: str, vector: List[float], id: int):
        """Append *vector* with caller id *id* to the database and persist it.

        The database is created on first use with the dimension of *vector*.

        Raises:
            HTTPException: 400 if a name contains a path separator, the vector
                is empty, or its length differs from the index dimension.
        """
        # These names end up in a file path, so refuse anything that could
        # climb out of base_path before touching memory or disk.
        self._reject_unsafe_name(namespace, "namespace")
        self._reject_unsafe_name(identifier, "identifier")
        if not vector:
            raise HTTPException(status_code=400, detail="Vector must not be empty")

        db_key = f"{namespace}_{identifier}"
        db = self.create_or_get_db(namespace, identifier, len(vector))
        if len(vector) != db.d:
            # Without this check FAISS fails with an internal assertion.
            raise HTTPException(
                status_code=400,
                detail=f"Vector dimension {len(vector)} does not match index dimension {db.d}",
            )

        np_vector = np.array([vector], dtype=np.float32)
        db.add(np_vector)
        self.vectors[db_key].append(np_vector)
        self.ids[db_key].append(id)
        self.save_db(namespace, identifier)

    def query(self, namespace: str, identifier: str, query_vector: List[float], k: int):
        """Return up to *k* nearest neighbours of *query_vector*.

        Returns:
            A list of ``{"id": ..., "score": ...}`` dicts in the order FAISS
            returned them, where ``score`` is the distance reported by the
            index.  ``id`` is ``None`` when the index holds a row for which no
            id was recorded.  Fewer than *k* entries are returned when the
            index holds fewer than *k* vectors.

        Raises:
            HTTPException: 404 if the database does not exist; 400 if *k* is
                not a positive integer or the query has the wrong dimension.
        """
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            raise HTTPException(status_code=404, detail="Database not found")
        if k is None or k < 1:
            raise HTTPException(status_code=400, detail="num_results must be a positive integer")

        db = self.databases[db_key]
        if len(query_vector) != db.d:
            raise HTTPException(
                status_code=400,
                detail=f"Query dimension {len(query_vector)} does not match index dimension {db.d}",
            )

        np_query = np.array([query_vector], dtype=np.float32)
        distances, indices = db.search(np_query, k)

        known_ids = self.ids[db_key]
        results = []
        for idx, score in zip(indices[0], distances[0]):
            # FAISS pads the result with label -1 when it has fewer than k
            # vectors.  A plain ``idx < len(ids)`` test lets -1 through and
            # ids[-1] would report the *last* stored id as a match.
            if idx < 0:
                continue
            position = int(idx)
            matched_id = known_ids[position] if position < len(known_ids) else None
            results.append({"id": matched_id, "score": float(score)})
        return results
# Single module-level database shared by the route handlers below. Constructing
# it creates ./vector_dbs if missing and loads every saved index from it, so
# this does disk I/O at import time.
vector_db = VectorDatabase()
@app.post("/query")
def query_vector_db(request: QueryRequest):
    """Search one database for the nearest neighbours of the request's vector.

    Args:
        request: Parsed request body naming the database, the query vector
            and how many results to return.

    Returns:
        ``{"results": [...]}`` where the list is whatever ``vector_db.query``
        produced.
    """
    # num_results is Optional, so an explicit JSON null arrives here as None.
    # Fall back to the model's default of 10 instead of handing None to the
    # index search.
    num_results = 10 if request.num_results is None else request.num_results
    results = vector_db.query(
        request.namespace, request.identifier, request.query, num_results
    )
    return {"results": results}
@app.post("/add_vector")
def add_vector(namespace: str, identifier: str, vector: List[float], id: int):
    """Store one vector under the given database and report success.

    The arguments are forwarded unchanged to ``vector_db.add_vector``; the
    response is a fixed confirmation message.
    """
    vector_db.add_vector(
        namespace=namespace,
        identifier=identifier,
        vector=vector,
        id=id,
    )
    confirmation = {"message": "Vector added successfully"}
    return confirmation
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run directly.
    import uvicorn

    # Serve on every network interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.