import os
import ast
import astunparse
import logging
from typing import Optional, Union, List, Dict, Tuple

import pandas as pd
from pydantic import Field
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.tools import Tool
from langchain.python import PythonREPL
from langchain.tools.base import BaseTool
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.document_loaders import (
    WebBaseLoader,
    SeleniumURLLoader,
    UnstructuredPDFLoader,
    DataFrameLoader,
)
# google search, pdf parser, web crawler, similarity search, python shell
# Decorator that attaches a tool name and description to a function,
# so they can be read back later when the function is wrapped as a Tool.
def prompts(name: str, description: str):
    def decorator(func):
        func.name = name
        func.description = description
        return func
    return decorator
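
# A minimal sketch (hypothetical helper, not part of the original gist) showing
# how the attributes attached by @prompts can be read back to build a langchain
# Tool from any of the tool classes below:
def to_langchain_tool(tool_instance) -> Tool:
    func = tool_instance.inference
    return Tool(name=func.name, description=func.description, func=func)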
search = GoogleSearchAPIWrapper()
googleSearchTool = Tool(
    name="Google Search",
    description="Search Google for recent results.",
    func=search.run,
)
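
# Usage sketch (assumes GOOGLE_API_KEY and GOOGLE_CSE_ID are set in the
# environment, which GoogleSearchAPIWrapper reads by default):
def _demo_google_search():
    return googleSearchTool.run("langchain agent tools")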
class pdfParserTool:

    @prompts(name="pdf parser",
             description="The input to this tool should be a string, describing the pdf path. ")
    def inference(self, path: str):
        loader = UnstructuredPDFLoader(path)
        pages = loader.load_and_split()
        # the original indexed into an undefined `docs[file]`; return the page texts directly
        docs = [p.page_content for p in pages]
        return docs
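
# Usage sketch (the path is a hypothetical example; any local PDF works):
def _demo_pdf_parser():
    return pdfParserTool().inference("./example.pdf")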
class similaritySearchTool:

    def __init__(self, site: str, datafeed_path: str):
        # `datafeed_path` was referenced but never defined in the original;
        # take it as a constructor argument alongside the site name
        embeddings = OpenAIEmbeddings(
            deployment="ai-explore-text-embedding-ada-002",
            model="text-embedding-ada-002",
        )
        index_file = f'{site}.faiss_index'
        if os.path.exists(index_file):
            print('Load product faiss index')
            product_db = FAISS.load_local(index_file, embeddings)
        else:
            print('Create faiss index')
            df = pd.read_csv(datafeed_path)
            loader = DataFrameLoader(df, page_content_column="Title")
            documents = loader.load()
            product_db = FAISS.from_documents(documents, embeddings)
            product_db.save_local(index_file)
        self.product_db = product_db

    @prompts(name="Similarity search",
             description="The input to this tool should be a string, describing the product.")
    def inference(self, query):
        # similarity_search_with_score returns (Document, distance) pairs;
        # convert each distance to a rough relevance score
        docs = self.product_db.similarity_search_with_score(query, k=10)
        related = {}
        for doc in docs:
            if doc[0].page_content not in related:
                similarity = 1 - doc[1]
                related[doc[0].page_content] = similarity
        res = ""
        for e, (page_content, score) in enumerate(related.items()):
            res += f"{e}. product: {page_content} relevance: {score}\n"
        return res
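
# Usage sketch (site name and CSV path are hypothetical; the CSV must have a
# "Title" column for DataFrameLoader, and OPENAI_API_KEY must be set):
def _demo_similarity_search():
    tool = similaritySearchTool(site="my-best", datafeed_path="./products.csv")
    return tool.inference("wireless earbuds")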
class pythonShellTool:

    def __init__(self):
        # seed shared scope variables here; the original defined `global_dict`
        # and `local_dict` but the method used self.globals / self.locals
        self.globals: Dict = {}
        self.locals: Dict = {}

    @prompts(
        name="python_repl_ast",
        description="This is a python shell. Use this to execute python commands. "
                    "Input should be a valid python command. "
                    "The output of the command should not be too long."
    )
    def inference(self, query: str):
        try:
            tree = ast.parse(query)
            # execute every statement except the last, then evaluate the last
            # statement so its value can be returned
            module = ast.Module(tree.body[:-1], type_ignores=[])
            exec(astunparse.unparse(module), self.globals, self.locals)  # type: ignore
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = astunparse.unparse(module_end)  # type: ignore
            try:
                # stringify before slicing so non-sliceable results don't crash
                return str(eval(module_end_str, self.globals, self.locals))[:1024]
            except Exception:
                exec(module_end_str, self.globals, self.locals)
                return ""
        except Exception as e:
            return str(e)
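
# Usage sketch: the shell executes every statement, then evaluates the last
# one and returns its value as a string (truncated to 1024 characters):
def _demo_python_shell():
    shell = pythonShellTool()
    return shell.inference("x = 21\nx * 2")  # -> "42"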
class webCrawlerTool:

    def __init__(self):
        from langchain.chains.summarize import load_summarize_chain
        from langchain.prompts import PromptTemplate
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        prompt_template = """
        {text}
        """
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
        self.chain = load_summarize_chain(llm, chain_type="stuff", prompt=PROMPT)

    @prompts(name="Guide",
             description="This is a web crawler for a given product. "
                         "The output of this tool would be a string. "
                         "The input to this tool should be a string, which is a keyword describing the product."
             )
    def inference(self, query: str):
        loader = MySeleniumURLLoader(
            [f'https://my-best.tw/search_contents?q={query}'],
            executable_path='./chromedriver'
        )
        docs = loader.load()
        if docs:
            # get the best search result
            best_match = docs[0]
            print(best_match)
            # crawl the linked page and summarize its first 1024 characters
            loader = WebBaseLoader(best_match.metadata['href'])
            docs = loader.load()
            docs[0].page_content = docs[0].page_content[:1024]
            guide = self.chain.run(docs)
            return str([i + '\n' for i in guide.split('\n') if i.strip()]) + 'Please suggest this summary to the user.'
        else:
            return "Cannot find any additional information. Do not use Guide tool again."
class MySeleniumURLLoader(SeleniumURLLoader):

    def load(self) -> list:
        from selenium.webdriver.support.ui import WebDriverWait
        from selenium.webdriver.support import expected_conditions as EC
        from selenium.webdriver.common.by import By
        from selenium.common.exceptions import TimeoutException

        docs = []
        driver = self._get_driver()
        try:
            for url in self.urls:
                try:
                    driver.get(url)
                    # customize your web content here
                    # ===========
                    elements = WebDriverWait(driver, 10).until(
                        EC.presence_of_all_elements_located((By.CLASS_NAME, 'kc-list-items__item-container--pc'))
                    )
                    for e in elements:
                        metadata = {"href": e.get_attribute('href')}
                        docs.append(Document(page_content=e.text, metadata=metadata))
                    # ===========
                except TimeoutException:
                    print('Search result not found')
                except Exception as e:
                    if self.continue_on_failure:
                        print(f"Error fetching or processing {url}, exception: {e}")
                    else:
                        raise e
        finally:
            # the original returned from inside the loop and never reached
            # driver.quit(); close the browser here so it always shuts down
            driver.quit()
        return docs
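
# Usage sketch (hypothetical query; requires ./chromedriver as above). Each
# returned Document carries a search result's text and its href in metadata:
def _demo_selenium_loader():
    loader = MySeleniumURLLoader(
        ['https://my-best.tw/search_contents?q=keyboard'],
        executable_path='./chromedriver'
    )
    return loader.load()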