Skip to content

Instantly share code, notes, and snippets.

@PttCodingMan
Created July 15, 2023 05:51
Show Gist options
  • Save PttCodingMan/7cdb47d857c0cad987ade593ae86a765 to your computer and use it in GitHub Desktop.
Save PttCodingMan/7cdb47d857c0cad987ade593ae86a765 to your computer and use it in GitHub Desktop.
import logging
import re
from queue import PriorityQueue
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from post_walker import post_walker
# Log at INFO level; each record is rendered as "[MMDD HH:MM:SS] message".
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%m%d %H:%M:%S',
)
# Functions in this module refer to a global named `logger`. Binding it here
# (instead of only under the `__main__` guard) keeps them usable when the
# module is imported rather than executed as a script.
logger = logging.getLogger(__name__)

# Sentence-embedding model used to compare posts; loaded once at import time.
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
# Maximum number of related posts attached to each post.
related_top = 5
# A post whose raw text contains any of these markers is left untouched.
block_post_flags = ['top: true', 'hidden: true']
class Post:
    """A blog post together with the text used to compare it to other posts.

    Attributes:
        title: Post title as passed in.
        link: The post's ``abbrlink`` value.
        tags: Tag strings as passed in.
        raw_data: The complete, unmodified text of the post file.
        analytics_content: ``"<tags> <title> <body>"`` where the body is the
            text after the last ``---`` in ``raw_data``, with any existing
            ``## 相關文章`` section cut off.
        related_posts: List of related ``Post`` objects, initially empty.
        related_score: Initialised to 0.
    """

    def __init__(self, title: str, abbrlink: str, tags: list[str], raw_data: str):
        self.title = title
        self.link = abbrlink
        self.tags = tags
        self.raw_data = raw_data
        self.related_posts = []
        self.related_score = 0

        # The body is whatever follows the last '---'. If there is no '---'
        # at all, rfind() returns -1; use the whole text in that case instead
        # of silently dropping its first two characters.
        marker_at = raw_data.rfind('---')
        body = raw_data[marker_at + 3:] if marker_at != -1 else raw_data
        content = ' '.join(self.tags) + ' ' + self.title + ' ' + body.strip()

        # Cut off a previously generated related-posts section so that it
        # does not take part in the comparison text.
        section_at = content.rfind('## 相關文章')
        if section_at != -1:
            content = content[:section_at]
        self.analytics_content = content

    def add_related_post(self, related_post: 'Post') -> None:
        """Append *related_post* to this post's ``related_posts`` list."""
        # ``related_posts`` is a plain list, so it has no ``put``; the old
        # body also read a ``self.embedding`` attribute that is never set.
        self.related_posts.append(related_post)
def count_must_related_post():
    """Attach to every collected post its most similar other posts.

    Encodes the ``analytics_content`` of every post in ``posts`` with the
    module-level ``model``, computes the pairwise cosine-similarity matrix,
    and appends up to ``related_top`` other posts (highest score first) to
    each post's ``related_posts`` list. Progress is logged via ``logger``.
    """
    sentences = [p.analytics_content for p in posts.values()]
    embedding = model.encode(sentences, convert_to_tensor=False)
    cosine_scores = util.cos_sim(embedding, embedding)

    for i, sentence in enumerate(sentences):
        # Scores are negated so an ascending sort yields the most similar
        # post first; equal scores fall back to the lower index. A plain
        # sort replaces the thread-synchronised PriorityQueue, which was
        # only used as a throwaway heap.
        candidates = sorted(
            (-cosine_scores[i][j].item(), j)
            for j in range(len(sentences))
            if j != i
        )

        # ``posts`` is keyed by analytics_content, i.e. by the sentence itself.
        current_post = posts[sentence]
        logger.info('post %s top %s', current_post.title, related_top)
        for negated_score, j in candidates[:max(related_top, 0)]:
            related_post = posts[sentences[j]]
            logger.info('related post %s score %s', related_post.title, -negated_score)
            current_post.related_posts.append(related_post)
# ---- Module state shared by the post_walker callbacks ----
# Maps a post's analytics_content to its Post object.
posts = {}
# Maps a post's raw file text to its Post object.
raw_posts = {}
def is_block_post(raw_data: str) -> bool:
    """Return True if *raw_data* contains any marker from ``block_post_flags``."""
    return any(marker in raw_data for marker in block_post_flags)
def collect_posts(raw_data: str) -> str:
    """Register the post described by *raw_data* in ``posts`` and ``raw_posts``.

    The title, abbrlink and tag list are pulled out of the text with regular
    expressions. Posts flagged by ``is_block_post`` and posts missing any of
    the three fields are skipped.

    Args:
        raw_data: Full text of one post file.

    Returns:
        *raw_data*, always unchanged; this function only collects.
    """
    if is_block_post(raw_data):
        return raw_data

    title_match = re.search(r"title: (.+)", raw_data)
    abbrlink_match = re.search(r"abbrlink: (.+)", raw_data)
    tags_match = re.search(r'tags:\n((?:\s*-\s*.+\n)*)', raw_data)
    if not (title_match and abbrlink_match and tags_match):
        return raw_data

    title = title_match.group(1).strip()

    link = abbrlink_match.group(1).strip()
    # Drop one pair of surrounding quotes, single or double, and only when
    # both ends really carry the same quote character.
    if len(link) >= 2 and link[0] == link[-1] and link[0] in ('"', "'"):
        link = link[1:-1]

    # The tag pattern can run into a following '---' line, which then shows
    # up as a bogus '--' entry; filter out every such entry.
    found_tags = (
        tag.strip().lower()
        for tag in re.findall(r'-\s*(.+)', tags_match.group(1))
    )
    tags = sorted(tag for tag in found_tags if tag != '--')

    post = Post(title, link, tags, raw_data)
    # Both registries are only mutated, never rebound, so no `global` needed.
    posts[post.analytics_content] = post
    raw_posts[post.raw_data] = post
    return raw_data
def add_related_post(raw_data: str) -> str:
    """Return *raw_data* with a freshly built ``## 相關文章`` section appended.

    Text that was not registered in ``raw_posts`` is returned untouched. For
    registered text, an existing section (from its last heading onwards) is
    removed first; if the remaining text is flagged by ``is_block_post`` it
    is returned without a new section.
    """
    try:
        current_post = raw_posts[raw_data]
    except KeyError:
        return raw_data

    # Remove the old section, including the whitespace that preceded it.
    heading_at = raw_data.rfind('## 相關文章')
    if heading_at != -1:
        raw_data = raw_data[:heading_at].strip()

    if is_block_post(raw_data):
        return raw_data

    # One markdown bullet per related post, in stored order.
    bullets = ''.join(
        f'- [{entry.title}](/{entry.link})\n'
        for entry in current_post.related_posts
    )
    return raw_data + '\n\n## 相關文章\n\n' + bullets
if __name__ == '__main__':
    # Bound at module scope (an `if` block opens no new scope), so code in
    # this module that refers to the global name `logger` resolves to it.
    logger = logging.getLogger(__name__)
    # Pass 1: hand every post to collect_posts to fill the registries.
    # NOTE(review): presumably post_walker calls the callback with each
    # post's raw text and may write back the returned text; confirm in
    # the post_walker module.
    post_walker(collect_posts)
    # Compute the related posts for everything collected in pass 1.
    count_must_related_post()
    # Pass 2: walk the posts again so add_related_post can append the section.
    post_walker(add_related_post)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment