# main
llama-index
langchain
"""Modified llama-hub example for github_repo""" | |
import argparse | |
import logging | |
import os | |
import pickle | |
from langchain.chat_models import ChatOpenAI | |
from llama_index import ( | |
GPTSimpleVectorIndex, | |
LLMPredictor, | |
ServiceContext, | |
download_loader, | |
) | |
# from llama_index.logger.base import LlamaLogger | |
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode | |
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter | |
from llama_index.node_parser.simple import SimpleNodeParser | |
from llama_index.prompts.chat_prompts import CHAT_REFINE_PROMPT | |
assert (
    os.getenv("OPENAI_API_KEY") is not None
), "Please set the OPENAI_API_KEY environment variable."
assert (
    os.getenv("GITHUB_TOKEN") is not None
), "Please set the GITHUB_TOKEN environment variable."

# This is a way to test loaders on different forks/branches.
# LLAMA_HUB_CONTENTS_URL = "https://raw.githubusercontent.com/claysauruswrecks/llama-hub/bugfix/github-repo-splitter"  # noqa: E501
# LOADER_HUB_PATH = "/loader_hub"
# LOADER_HUB_URL = LLAMA_HUB_CONTENTS_URL + LOADER_HUB_PATH
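# download_loader fetches the loader module from llama-hub at runtime and
# caches it locally; the github_repo import below only works after this call.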
download_loader(
    "GithubRepositoryReader",
    # loader_hub_url=LOADER_HUB_URL,
    # refresh_cache=True,
)

from llama_index.readers.llamahub_modules.github_repo import (  # noqa: E402
    GithubClient,
    GithubRepositoryReader,
)
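# Repos to index, pinned to exact commits. Keys take the form
# "owner/repo@commit_sha"; main() splits them back apart below.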
# TODO: Modify github loader to support exclude list of filenames and unblock .ipynb  # noqa: E501
REPOS = {
    # NOTE: Use this to find long line filetypes to avoid: `find . -type f -exec sh -c 'awk "BEGIN { max = 0 } { if (length > max) max = length } END { printf \"%s:%d\n\", FILENAME, max }" "{}"' \; | sort -t: -k2 -nr`  # noqa: E501
    "jerryjliu/llama_index@1b739e1fcd525f73af4a7131dd52c7750e9ca247": dict(
        filter_directories=(
            ["docs", "examples", "gpt_index", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [
                ".bat",
                ".md",
                # ".ipynb",
                ".py",
                ".rst",
                ".sh",
            ],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
    "emptycrown/llama-hub@8312da4ee8fcaf2cbbf5315a2ab8f170d102d081": dict(
        filter_directories=(
            ["loader_hub", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [".py", ".md", ".txt"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
    "hwchase17/langchain@d85f57ef9cbbbd5e512e064fb81c531b28c6591c": dict(
        filter_directories=(
            ["docs", "langchain", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [
                ".bat",
                ".md",
                # ".ipynb",
                ".py",
                ".rst",
                ".sh",
            ],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
}
# MODEL_NAME = "gpt-3.5-turbo"
MODEL_NAME = "gpt-4"
CHUNK_SIZE_LIMIT = 512
CHUNK_OVERLAP = 200  # default
MAX_TOKENS = None  # Set to None to use model's maximum

EMBED_MODEL = OpenAIEmbedding(mode=OpenAIEmbeddingMode.SIMILARITY_MODE)
LLM_PREDICTOR = LLMPredictor(
    llm=ChatOpenAI(
        temperature=0.0, model_name=MODEL_NAME, max_tokens=MAX_TOKENS
    )
)

PICKLE_DOCS_DIR = os.path.join(
    os.path.dirname(__file__), "data", "pickled_docs"
)

# Create the directory if it does not exist
if not os.path.exists(PICKLE_DOCS_DIR):
    os.makedirs(PICKLE_DOCS_DIR)
def load_pickle(filename):
    """Load the pickled embeddings"""
    with open(os.path.join(PICKLE_DOCS_DIR, filename), "rb") as f:
        logging.debug(f"Loading pickled embeddings from {filename}")
        return pickle.load(f)


def save_pickle(obj, filename):
    """Save the pickled embeddings"""
    with open(os.path.join(PICKLE_DOCS_DIR, filename), "wb") as f:
        logging.debug(f"Saving pickled embeddings to {filename}")
        pickle.dump(obj, f)
def main(args):
    """Run the trap."""
    g_docs = {}
    for repo in REPOS.keys():
        logging.debug(f"Processing {repo}")
        repo_owner, repo_name_at_sha = repo.split("/")
        repo_name, commit_sha = repo_name_at_sha.split("@")
        docs_filename = f"{repo_owner}-{repo_name}-{commit_sha}-docs.pkl"
        docs_filepath = os.path.join(PICKLE_DOCS_DIR, docs_filename)

        # Reuse previously fetched documents when a pickle for this
        # owner/repo/sha already exists on disk.
        if os.path.exists(docs_filepath):
            logging.debug(f"Path exists: {docs_filepath}")
            g_docs[repo] = load_pickle(docs_filename)

        if not g_docs.get(repo):
            github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
            loader = GithubRepositoryReader(
                github_client,
                owner=repo_owner,
                repo=repo_name,
                filter_directories=REPOS[repo]["filter_directories"],
                filter_file_extensions=REPOS[repo]["filter_file_extensions"],
                verbose=args.debug,
                concurrent_requests=10,
            )
            embedded_docs = loader.load_data(commit_sha=commit_sha)
            g_docs[repo] = embedded_docs
            save_pickle(embedded_docs, docs_filename)

    # NOTE: set a chunk size limit to < 1024 tokens
    service_context = ServiceContext.from_defaults(
        llm_predictor=LLM_PREDICTOR,
        embed_model=EMBED_MODEL,
        node_parser=SimpleNodeParser(
            text_splitter=TokenTextSplitter(
                separator=" ",
                chunk_size=CHUNK_SIZE_LIMIT,
                chunk_overlap=CHUNK_OVERLAP,
                backup_separators=[
                    "\n",
                    "\n\n",
                    "\r\n",
                    "\r",
                    "\t",
                    "\\",
                    "\f",
                    "//",
                    "+",
                    "=",
                    ",",
                    ".",
                    "a",
                    "e",  # TODO: Figure out why lol
                ],
            )
        ),
        # llama_logger=LlamaLogger(),  # TODO: ?
    )

    # Collapse all the docs into a single list
    logging.debug("Collapsing all the docs into a single list")
    docs = []
    for repo in g_docs.keys():
        docs.extend(g_docs[repo])

    index = GPTSimpleVectorIndex.from_documents(
        documents=docs, service_context=service_context
    )

    # Ask for CLI input in a loop
    while True:
        print("QUERY:")
        query = input()
        answer = index.query(query, refine_template=CHAT_REFINE_PROMPT)
        print(f"ANSWER: {answer}")
        if args.pdb:
            import pdb
            pdb.set_trace()
if __name__ == "__main__":
    # Parse CLI arguments (inside the guard so importing this
    # module doesn't trigger argument parsing).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debug logging.",
    )
    parser.add_argument(
        "--pdb",
        action="store_true",
        help="Invoke PDB after each query.",
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    main(args)
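A typical invocation, assuming the file is saved as vectorize_repos.py (the filename used in the gist URL referenced in the comments below):

python vectorize_repos.py --debug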
@lbedner - No worries, I hope it can become helpful for you.
GPT-4 API access is in private beta, and not many people currently have access, so you'll have to make do with gpt-3.5-turbo
or others until it launches more widely.
You can check which models you have access to by navigating in your browser to https://api.openai.com/v1/models and submitting your API key into the password field for HTTP Basic Auth.
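If you'd rather check from a script, here's a minimal sketch of the same model-listing call using the requests library (the Bearer header is the API's standard auth; the endpoint is the one linked above):

import os
import requests

# List the models this API key can access.
resp = requests.get(
    "https://api.openai.com/v1/models",
    headers={"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"},
)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"])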
It's not intuitive, but a paid "ChatGPT" (public-facing web app) account is different from your "API" account, and you'll need to add additional payment and contact details at the mentioned address: https://platform.openai.com/account/billing
The documentation around using different LLMs is here: https://gpt-index.readthedocs.io/en/latest/how_to/customization/custom_llms.html
The documentation for the chat vs completion is here: https://platform.openai.com/docs/guides/chat/chat-vs-completions
The documentation on chat compatibility is here: https://platform.openai.com/docs/models/model-endpoint-compatibility
In summary, to fix the errors and get it working with models you should have access to:
- Use the 3.5-turbo model defined on https://gist.github.com/claysauruswrecks/ff68efd81b98401b44456e0f25c41f76#file-vectorize_repos-py-L94 (see the sketch after this list).
- Make sure additional payment details are filled out for the API-specific account system.
- Maybe comment out all repos except for a small one (also sketched below).
- You will probably not want to change the embedding model, as it could become very expensive: https://platform.openai.com/docs/guides/embeddings/embedding-models
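Concretely, those two edits are just a constant swap plus trimming the REPOS dict; a minimal sketch (the llama-hub entry is copied from the gist's own REPOS dict as an example of a smaller repo):

# Swap the model constant near the top of vectorize_repos.py:
MODEL_NAME = "gpt-3.5-turbo"
# MODEL_NAME = "gpt-4"

# And keep only one small repo in REPOS while testing:
REPOS = {
    "emptycrown/llama-hub@8312da4ee8fcaf2cbbf5315a2ab8f170d102d081": dict(
        filter_directories=(
            ["loader_hub", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [".py", ".md", ".txt"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
}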
Hey, sorry for the late reply, life and such... So:
- I was able to get it working with your help
- thanks for the really important info on the difference between my ChatGPT+ account and the GPT-4 beta API access, definitely did not know that, but it makes so much more sense now
- I spent some days reading about embedding models, thanks to you
In short, thanks for the lesson 🦾 , and keep up the good work
This is amazing, thank you for this. A few questions/comments:
- `MODEL_NAME = "gpt-4"` doesn't work for me, I get this:
- I am continually getting the following log for a few minutes before anything happens:
  I do have a paid account, so wondering what can be done about this? I assume nothing, and that it's coming from some internal code, but it would be nice if this could be controlled 🤔
- I get the following error when using the LLM, `text-davinci-03`:
  How would I change the endpoint that is used? 🤔
Thanks again for this wonderful gist, this has me on the right path!