Skip to content

Instantly share code, notes, and snippets.

@7shi
Last active September 15, 2024 08:25
Show Gist options
  • Save 7shi/c589bba6e739304a5098c8a3f2f55cc8 to your computer and use it in GitHub Desktop.
Save 7shi/c589bba6e739304a5098c8a3f2f55cc8 to your computer and use it in GitHub Desktop.
[py] test Ruri text embeddings
# Encode every non-blank line of a text file with the Ruri embedding model
# and save the result next to it as "<textfile stem>.safetensors".
#
# Usage: <script> <textfile>
#
# The output file holds a single float32 tensor under the key "lines" with
# shape (number of non-blank lines, embedding dimension). Row i is the
# embedding of the i-th non-blank (stripped) line, encoded with the
# passage-side prefix "文章: ".
import sys
args = sys.argv[1:]
if len(args) != 1:
    print(f"Usage: {sys.argv[0]} <textfile>", file=sys.stderr)
    sys.exit(1)
# Heavy imports are deferred until after the argument check so that the
# usage message is printed without waiting for torch to load.
import os, torch, safetensors.torch
from sentence_transformers import SentenceTransformer
textfile = args[0]
tensorfile = os.path.splitext(textfile)[0] + ".safetensors"
# Download from the 🤗 Hub
model = SentenceTransformer("cl-nagoya/ruri-base")
# Don't forget to add the prefix "クエリ: " for query-side or "文章: " for passage-side texts.
PASSAGE_PREFIX = "文章: "
# Read explicitly as UTF-8: the input is Japanese text, and the platform
# default encoding (e.g. cp932 on Windows) would mis-decode it or fail.
with open(textfile, "r", encoding="utf-8") as f:
    lines = [s for line in f if (s := line.strip())]
# Embed a throwaway sentence once to learn the embedding dimension.
probe = model.encode([PASSAGE_PREFIX + "test"], convert_to_tensor=True)[0]
tensor = torch.zeros(len(lines), len(probe), dtype=torch.float32)
for i, line in enumerate(lines):
    # Progress: "<1-based index> / <total> <text>"
    print(f"{i+1} / {len(lines)} {line}")
    tensor[i, :] = model.encode([PASSAGE_PREFIX + line], convert_to_tensor=True)[0]
safetensors.torch.save_file({"lines": tensor}, tensorfile)
# Interactive semantic search over a text file using precomputed Ruri
# embeddings.
#
# Usage: <script> <textfile>
#
# Expects "<textfile stem>.safetensors" next to <textfile>, containing a
# "lines" tensor with one row per non-blank line of <textfile> (as written
# by the companion encoder script). Each query typed at the "> " prompt is
# prefixed with "クエリ: ", embedded, and compared by cosine similarity
# against every stored row. The best matches are printed as
# "rank: score line_number text". EOF (Ctrl-D / Ctrl-Z) or Ctrl-C at the
# prompt exits.
import sys
args = sys.argv[1:]
if len(args) != 1:
    print(f"Usage: {sys.argv[0]} <textfile>", file=sys.stderr)
    sys.exit(1)
# Heavy imports are deferred until after the argument check so that the
# usage message is printed without waiting for torch to load.
import os, torch, torch.nn.functional as F, safetensors.torch
from sentence_transformers import SentenceTransformer
TOP_K = 10
textfile = args[0]
tensorfile = os.path.splitext(textfile)[0] + ".safetensors"
# Download from the 🤗 Hub
model = SentenceTransformer("cl-nagoya/ruri-base")
# Read explicitly as UTF-8 so the lines match what the encoder read.
with open(textfile, "r", encoding="utf-8") as f:
    lines = [s for line in f if (s := line.strip())]
tensor = safetensors.torch.load_file(tensorfile)["lines"]
# A stale embedding file would otherwise give misaligned results or an
# IndexError when printing lines[idx].
if tensor.shape[0] != len(lines):
    print(
        f"{tensorfile}: {tensor.shape[0]} embeddings but {len(lines)} lines "
        f"in {textfile}; regenerate the embeddings",
        file=sys.stderr,
    )
    sys.exit(1)
# topk raises if k exceeds the number of rows, so clamp for short files.
k = min(TOP_K, len(lines))
# Don't forget to add the prefix "クエリ: " for query-side or "文章: " for passage-side texts.
while True:
    print()
    try:
        q = input("> ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    sentences = ["クエリ: " + q]
    embeddings = model.encode(sentences, convert_to_tensor=True)
    # encode() returns tensors on the model's device (CUDA when available),
    # while load_file() returned CPU tensors. Align device and dtype before
    # comparing, or cosine_similarity raises.
    embeddings = embeddings.to(device=tensor.device, dtype=tensor.dtype)
    # (N, D) vs (1, D) broadcasts to one similarity per stored line.
    similarities = F.cosine_similarity(tensor, embeddings, dim=1)
    values, indices = torch.topk(similarities, k=k)
    for rank, (v, idx) in enumerate(zip(values.tolist(), indices.tolist()), 1):
        print(f"{rank:2d}: {v:.5f} {idx + 1:4d} {lines[idx]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment