Last active
June 3, 2023 03:58
-
-
Save j40903272/bf14f8f0dcd92cea4eea60a816dfcacf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import logging
import os
from typing import Optional, Union, List, Dict, Tuple

import astunparse
import pandas as pd
from pydantic import Field

from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.document_loaders import WebBaseLoader, SeleniumURLLoader
from langchain.document_loaders import DataFrameLoader, UnstructuredPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.python import PythonREPL
from langchain.tools import Tool
from langchain.tools.base import BaseTool
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.vectorstores import FAISS
# google search, pdf parser, web crawler, similarity search, python shell | |
def prompts(name: str, description: str):
    """Decorator factory: tag a function with the tool metadata langchain expects.

    The decorated callable gains ``name`` and ``description`` attributes and
    is otherwise returned unchanged (no wrapper is created).
    """
    def attach(func):
        setattr(func, "name", name)
        setattr(func, "description", description)
        return func
    return attach
# Google-search tool: a thin langchain Tool wrapper around the
# GoogleSearchAPIWrapper's run() method.
search = GoogleSearchAPIWrapper()
googleSearchTool = Tool(
    name="Google Search",
    func=search.run,
    description="Search Google for recent results.",
)
class pdfParserTool:
    """Tool that extracts the text of a PDF, returned page by page."""

    @prompts(name="pdf parser",
             description="The input to this tool should be a string, describing the pdf path. ")
    def inference(self, path: str) -> Dict[str, List[str]]:
        """Load the PDF at *path* and return ``{path: [page texts]}``.

        Bug fix vs. original: it assigned to ``docs[file]`` where both
        ``docs`` and ``file`` were undefined names, so every call raised
        NameError. The dict is now built from the input ``path``.
        """
        loader = UnstructuredPDFLoader(path)
        pages = loader.load_and_split()
        # Key by the input path so callers can tell which file the
        # page texts came from.
        docs = {path: [p.page_content for p in pages]}
        return docs
class similaritySearchTool:
    """Similarity search over a product catalogue using a FAISS index.

    On construction, loads a cached FAISS index for *site* if one exists on
    disk, otherwise builds it from a product CSV datafeed and saves it.
    """

    def __init__(self, site: str, datafeed_path: str = "datafeed.csv"):
        """Build or load the product index for *site*.

        :param site: name used to derive the on-disk index path.
        :param datafeed_path: CSV file with a "Title" column, used only
            when the index must be (re)built. Bug fixes vs. original:
            ``datafeed_path`` was an undefined global (now a defaulted
            parameter), and the existence check used the undefined name
            ``product_index`` (now ``index_file``).
        """
        embeddings = OpenAIEmbeddings(
            deployment="ai-explore-text-embedding-ada-002",
            model="text-embedding-ada-002",
        )
        index_file = f'{site}.faiss_index'
        if os.path.exists(index_file):
            print('Load product faiss index')
            product_db = FAISS.load_local(index_file, embeddings)
        else:
            print('Create faiss index')
            df = pd.read_csv(datafeed_path)
            loader = DataFrameLoader(df, page_content_column="Title")
            documents = loader.load()
            product_db = FAISS.from_documents(documents, embeddings)
            # Cache the index so later constructions skip the embedding step.
            product_db.save_local(index_file)
        self.product_db = product_db

    @prompts(name="Similarity search",
             description="The input to this tool should be a string, describing the product.")
    def inference(self, query: str) -> str:
        """Return the top-10 matches as numbered "product ... relevance ..." lines."""
        docs = self.product_db.similarity_search_with_score(query, k=10)
        related: Dict[str, float] = {}
        for doc, distance in docs:
            if doc.page_content not in related:
                # FAISS reports a distance; convert to a similarity score.
                related[doc.page_content] = 1 - distance
        res = ""
        for e, (page_content, score) in enumerate(related.items()):
            # Bug fix vs. original: dropped a stray ")" after the score.
            res += f"{e}. product: {page_content} relevance: {score}\n"
        return res
class pythonShellTool:
    """Executes a Python snippet and returns the value of its final
    expression as a string truncated to 1024 characters.

    All statements before the last are exec'd; the last is eval'd if it is
    an expression, exec'd otherwise. State persists across calls via the
    instance's global/local namespaces.
    """

    def __init__(self):
        # Per-instance namespaces so separate tool instances do not share
        # state. Bug fix vs. original: it defined class-level
        # ``global_dict``/``local_dict`` but then referenced the
        # nonexistent ``self.globals``/``self.locals``, raising
        # AttributeError on every call.
        self.global_dict: Dict = {}
        self.local_dict: Dict = {}

    @prompts(
        name="python_repl_ast",
        description="This is a python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "The output of the command should not be too long."
    )
    def inference(self, query: str) -> str:
        """Run *query* and return its final value as a (truncated) string.

        Returns "" when the last statement has no value (e.g. an
        assignment), and the exception message string on any failure.

        SECURITY NOTE: exec/eval on model-provided input is inherently
        unsafe; only run inside a sandboxed environment.
        """
        try:
            tree = ast.parse(query)
            # Execute everything except the final statement.
            module = ast.Module(tree.body[:-1], type_ignores=[])
            exec(astunparse.unparse(module), self.global_dict, self.local_dict)  # type: ignore
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = astunparse.unparse(module_end)  # type: ignore
            try:
                result = eval(module_end_str, self.global_dict, self.local_dict)
                # Bug fixes vs. original: str() first so non-sliceable
                # results (ints, floats) are returned instead of raising
                # TypeError on [:1024]; and only SyntaxError falls through
                # to exec, so a failing expression is no longer executed
                # twice (duplicating its side effects).
                return "" if result is None else str(result)[:1024]
            except SyntaxError:
                # Last statement is not an expression — just execute it.
                exec(module_end_str, self.global_dict, self.local_dict)
                return ""
        except Exception as e:
            return str(e)
class webCrawlerTool:
    """Searches my-best.tw for a product keyword, crawls the best-matching
    article, and returns an LLM-generated summary of (the first 1024
    characters of) it.
    """

    def __init__(self):
        # Imported lazily so merely importing this module does not pull in
        # the summarize-chain machinery.
        from langchain.chains.summarize import load_summarize_chain
        from langchain.prompts import PromptTemplate
        # NOTE(review): ChatOpenAI is not imported at the top of this file —
        # confirm it is in scope before this class is instantiated.
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        # Pass-through template: the "stuff" chain feeds the page text
        # directly to the LLM.
        prompt_template = """
        {text}
        """
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
        self.chain = load_summarize_chain(llm, chain_type="stuff", prompt=PROMPT)

    @prompts(name="Guide",
             description="This is a web crawler for certain product."
             "The output of this tool would be a string."
             "The input to this tool should be a string, which is a keyword describing the product."
             )
    def inference(self, query: str):
        """Crawl for *query* and return a summary string (or a fallback message).

        Relies on ``MySeleniumURLLoader`` (defined later in this file) and a
        local ``./chromedriver`` binary — presumably run from the repo root;
        verify against deployment layout.
        """
        loader = MySeleniumURLLoader(
            [f'https://my-best.tw/search_contents?q={query}'],
            executable_path='./chromedriver'
        )
        docs = loader.load()
        if docs:
            # get the best search result
            best_match = docs[0]
            print(best_match)
            # crawl the matched article page
            loader = WebBaseLoader(best_match.metadata['href'])
            docs = loader.load()
            # Truncate so the "stuff" chain's prompt stays small.
            docs[0].page_content = docs[0].page_content[:1024]
            guide = self.chain.run(docs)
            # Trailing runtime string (left untranslated) means: "Please give
            # this summary as a suggestion to the user."
            return str([i+'\n' for i in guide.split('\n') if i.strip()]) + '請把這段摘要建議給user.'
        else:
            return "Cannot find any additional information. Do not use Guide tool again."
class MySeleniumURLLoader(SeleniumURLLoader):
    """SeleniumURLLoader variant that scrapes my-best.tw search-result cards.

    Produces one ``Document`` per result card, with ``metadata['href']``
    pointing at the result page.
    """

    def load(self) -> list:
        """Load the configured URLs, returning scraped ``Document``s.

        Bug fix vs. original: the webdriver is now always quit via
        ``try``/``finally`` — the original returned from inside the loop
        (and could raise), leaking the browser process. The per-URL
        selenium imports are hoisted out of the loop, and the unused
        ``text`` local is removed.
        """
        from selenium.webdriver.support.ui import WebDriverWait
        from selenium.webdriver.support import expected_conditions as EC
        from selenium.webdriver.common.by import By
        from selenium.common.exceptions import TimeoutException

        docs = []
        driver = self._get_driver()
        try:
            for url in self.urls:
                try:
                    driver.get(url)
                    # customize your web content here
                    # ===========
                    try:
                        elements = WebDriverWait(driver, 10).until(
                            EC.presence_of_all_elements_located((By.CLASS_NAME, 'kc-list-items__item-container--pc'))
                        )
                        for e in elements:
                            metadata = {"href": e.get_attribute('href')}
                            docs.append(Document(page_content=e.text, metadata=metadata))
                        # Preserve original behavior: stop after the first
                        # URL that yields results.
                        return docs
                    except TimeoutException:
                        print('Search result not found')
                    # ===========
                except Exception as e:
                    if self.continue_on_failure:
                        print(f"Error fetching or processing {url}, exception: {e}")
                    else:
                        raise
        finally:
            driver.quit()
        return docs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment