@BlackHC
Last active May 12, 2023 15:08
Until LangChain adds support for a ChatOpenAI cache, here is a drop-in class that adds it.
# PoC to cache chat prompts. Drop into your code.
# Andreas 'blackhc' Kirsch, 2023
from typing import List, Optional

import langchain
from langchain.cache import SQLiteCache
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    Generation,
)

# Use a local SQLite database as the shared LLM cache.
langchain.llm_cache = SQLiteCache(".chat.langchain.db")


class CachedChatOpenAI(ChatOpenAI):
    def _generate(self, messages: List[BaseMessage], *args, **kwargs) -> ChatResult:
        # NOTE: the cache currently does not respect additional arguments (stop, etc.)
        # beyond the messages.
        messages_prompt = repr(messages)

        # Try to serve the result from the cache first.
        if langchain.llm_cache:
            results = langchain.llm_cache.lookup(messages_prompt, self.model_name)
            if results:
                assert len(results) == 1
                result: Generation = results[0]
                chat_result = ChatResult(
                    generations=[ChatGeneration(message=AIMessage(content=result.text))],
                    llm_output=result.generation_info,
                )
                return chat_result

        # Cache miss: call the API and store the result.
        chat_result = super()._generate(messages, *args, **kwargs)
        if langchain.llm_cache:
            assert len(chat_result.generations) == 1
            result = Generation(
                text=chat_result.generations[0].message.content,
                generation_info=chat_result.llm_output,
            )
            langchain.llm_cache.update(messages_prompt, self.model_name, [result])
        return chat_result


chat_model = CachedChatOpenAI(max_tokens=512, model_kwargs=dict(temperature=0.))
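
A minimal usage sketch (illustrative, not part of the gist): with the SQLite cache in place, repeating a call with identical messages should be answered from .chat.langchain.db instead of the OpenAI API.

from langchain.schema import HumanMessage

messages = [HumanMessage(content="Summarize the plot of Hamlet in one sentence.")]
first = chat_model(messages)   # hits the API and populates the cache
second = chat_model(messages)  # identical prompt, served from .chat.langchain.db
print(first.content == second.content)  # expected: True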
@drorhilman

from typing import List, Optional
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    Generation
)

@BlackHC
Author

BlackHC commented May 11, 2023

Thanks! Updated 🤗

@Erliz

Erliz commented May 12, 2023

For version v0.0.154, you should add:

    run_manager: Optional[CallbackManagerForLLMRun] = None,

because of https://github.com/hwchase17/langchain/blob/258c3198559da5844be3f78680f42b2930e5b64b/langchain/chat_models/openai.py#L270
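
For illustration, the override's signature under v0.0.154 might then look roughly like this (a sketch; the CallbackManagerForLLMRun import path is assumed and the keyword names follow the linked source):

from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult


class CachedChatOpenAI(ChatOpenAI):
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        # ... caching logic as above, forwarding stop and run_manager to super() ...
        return super()._generate(messages, stop=stop, run_manager=run_manager)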

@BlackHC
Author

BlackHC commented May 12, 2023

Thanks! I've replaced the additional arguments with *args, **kwargs now. However, the cache does not respect stop etc., so beware of relying on that 🙏
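
If the cache needs to distinguish calls with different stop sequences, one hypothetical tweak (not part of the gist) is to fold them into the cache key, for example:

from typing import List, Optional

from langchain.schema import BaseMessage


def cache_key(messages: List[BaseMessage], stop: Optional[List[str]] = None) -> str:
    # Include the stop sequences in the key so that calls with different `stop`
    # lists do not collide. Avoid including objects with unstable reprs
    # (e.g. a run_manager), which would defeat the cache.
    return repr((messages, stop))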
