Last active: June 19, 2024 03:15
Save failable/0379edf7a5d82024a69a50194295372f to your computer and use it in GitHub Desktop.
Qdrant viewer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import taipy.gui.builder as tgb | |
from openai import OpenAI | |
from qdrant_client import QdrantClient | |
from qdrant_client.models import PayloadFieldSchema | |
from taipy.gui import Gui | |
OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"


def create_embeddings(
    client: OpenAI,
    doc: str,
    model: str = OPENAI_EMBEDDING_MODEL,
    **kwargs,
) -> list[float]:
    """Embed a single document with the OpenAI embeddings API.

    Args:
        client: OpenAI client whose ``embeddings.create`` endpoint is called.
        doc: Text to embed.
        model: Embedding model name; defaults to ``OPENAI_EMBEDDING_MODEL``
            so the default lives in exactly one place.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``client.embeddings.create``.

    Returns:
        The embedding vector of the first (and only) input item.
    """
    response = client.embeddings.create(input=doc, model=model, **kwargs)
    # One input document -> one embedding; return its vector, not the wrapper.
    return response.data[0].embedding
def create_filter(field: str, schema: PayloadFieldSchema, value: str) -> dict:
    """Build a Qdrant field-condition dict that matches ``value`` on ``field``.

    Only the ``data_type`` attribute of ``schema`` is read:
      * ``"text"``    -> ``{"key": field, "match": {"text": value}}``
      * ``"keyword"`` -> ``{"key": field, "match": {"value": value}}``

    Raises:
        ValueError: for any other data type.

    NOTE(review): the schemas produced by ``get_payload_schema`` come from the
    collection's ``payload_schema`` entries; confirm ``PayloadFieldSchema`` is
    the right annotation for them.
    """
    if schema.data_type == "text":
        match_key = "text"
    elif schema.data_type == "keyword":
        match_key = "value"
    else:
        msg = f"Unsupported schema: {schema.data_type}"
        raise ValueError(msg)
    return {"key": field, "match": {match_key: value}}
def get_collection_names(qdrant_client: QdrantClient) -> list[str]:
    """Return the names of every collection on the server, sorted alphabetically."""
    listing = qdrant_client.get_collections()
    names = [collection.name for collection in listing.collections]
    names.sort()
    return names
def get_payload_schema(qdrant_client: QdrantClient, collection_name: str) -> dict:
    """Return the collection's filterable payload fields in a stable display order.

    Only fields whose ``data_type`` is ``"keyword"`` or ``"text"`` are kept —
    the two types ``create_filter`` can build a match for. The result maps
    field name -> schema entry, ordered by data type first and field name
    second, so the filter inputs always appear in the same order.
    """
    allowed_types = ("keyword", "text")
    collection = qdrant_client.get_collection(collection_name)
    kept = [
        (name, field_schema)
        for name, field_schema in collection.payload_schema.items()
        if field_schema.data_type in allowed_types
    ]
    kept.sort(key=lambda pair: (pair[1].data_type, pair[0]))
    return dict(kept)
def on_qdrant_url_change(state, var, val):
    """Taipy on_change callback for the "Qdrant url" input.

    Builds a client for the new URL (``val``), stores it both in the module
    global and in the session ``state``, then reloads the collection list,
    reselects the first collection and rebuilds the filter inputs so the UI
    no longer points at the previous server.
    """
    import logging

    global qdrant_client
    qdrant_client = QdrantClient(val)
    # Other callbacks (refresh_filters) read the client via ``state``;
    # rebinding the module global alone does not update the session's copy.
    state.qdrant_client = qdrant_client
    try:
        collection_names = get_collection_names(qdrant_client)
    except Exception:
        # The callback can fire while the URL is still being edited, so an
        # unreachable or half-typed URL is expected; degrade to "no
        # collections" instead of failing the callback.
        logging.getLogger(__name__).warning(
            "Could not list Qdrant collections at %s", val, exc_info=True
        )
        collection_names = []
    state.collection_names = collection_names
    state.collection_name = collection_names[0] if collection_names else None
    # Programmatic state changes don't trigger the selector's on_change,
    # so rebuild the filters explicitly.
    refresh_filters(state)
def on_collection_names_change(state, var, val):
    """Taipy on_change callback for the collection selector.

    ``refresh_filters`` reads the current selection from ``state`` itself, so
    ``var`` and ``val`` are not used here.
    """
    refresh_filters(state=state)
def refresh_filters(state):
    """Rebuild the filter partial for the currently selected collection.

    Renders one text input per filterable payload field (as returned by
    ``get_payload_schema``), labelled ``"<field> (<Type>)"``. The partial is
    cleared whenever there is no client, no selected collection, or no
    filterable field, so inputs from a previously selected collection never
    linger.
    """
    payload_schema = {}
    if state.qdrant_client and state.collection_name:
        payload_schema = get_payload_schema(
            state.qdrant_client,
            state.collection_name,
        )
    if not payload_schema:
        state.filter_partial.update_content(state, "")
        return
    with tgb.Page() as filter_part:
        for field, schema in payload_schema.items():
            # Pass the text as ``label``: the first positional argument of
            # ``tgb.input`` is ``value``, which would pre-fill the box with it.
            tgb.input(
                value="",
                label=f"{field} ({str.capitalize(schema.data_type)})",
            )
    state.filter_partial.update_content(state, filter_part)
def on_init(state):
    """Taipy on_init hook: populate the filter inputs for the initial collection."""
    refresh_filters(state=state)
if __name__ == "__main__":
    # These module-level names are what the page binds to via "{name}"
    # expressions and what the callbacks read back through ``state``.
    qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
    qdrant_client = QdrantClient(qdrant_url)
    collection_names = get_collection_names(qdrant_client)
    # Preselect the first (alphabetical) collection; None if the server has none.
    collection_name = collection_names[0] if collection_names else None

    with tgb.Page() as page:
        with tgb.expandable("Options"):
            tgb.input(
                value="{qdrant_url}",
                label="Qdrant url",
                on_change=on_qdrant_url_change,
            )
        with tgb.layout(columns="1 1"):
            # Bind the list of values (not a static copy) so that updating
            # ``state.collection_names`` is reflected in the dropdown.
            tgb.selector(
                value="{collection_name}",
                label="Collection",
                lov="{collection_names}",
                dropdown=True,
                on_change=on_collection_names_change,
            )
            tgb.input(value=10, label="Number of results")
        # Placeholder whose content refresh_filters() replaces per collection.
        tgb.part(partial="{filter_partial}")

    gui = Gui(page=page)
    # ``add_partial`` takes the partial's initial *content*, not a name:
    # start it empty and let refresh_filters() fill it in.
    filter_partial = gui.add_partial("")
    gui.run(
        title="Qdrant viewer",
        dark_mode=False,
        debug=True,
        use_reloader=True,
        port=5001,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.