Skip to content

Instantly share code, notes, and snippets.

@lambdaofgod
Created November 22, 2023 09:32
Show Gist options
  • Save lambdaofgod/d1aae99d037878a55accf45da3113dc6 to your computer and use it in GitHub Desktop.
Save lambdaofgod/d1aae99d037878a55accf45da3113dc6 to your computer and use it in GitHub Desktop.
Optimal transport in Python
import transformers
from sentence_transformers import SentenceTransformer


# Symptom phrases to compare against the reference sets below.
checked_symptoms = [
    "ból dupy",
    "jebie mnie w krzyżu",
]
# Each inner list is one reference set of symptom phrases.
reference_symptoms = [
    ["ból lędźwi"],
    ["ból gardła", "ból dupy"],
]

# One sentence-embedding model encodes every phrase list.
model = SentenceTransformer('sdadas/st-polish-paraphrase-from-distilroberta')
checked_embeddings = model.encode(checked_symptoms)
# One encoded result per reference set, in the same order as reference_symptoms.
reference_embeddings = list(map(model.encode, reference_symptoms))
import ot
import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from pydantic import BaseModel, Field
from typing import Callable


# Module-level aliases for one query/reference pair: the checked embeddings
# and the embeddings of the second reference set (index 1).
X_query, X_compared = checked_embeddings, reference_embeddings[1]


class MultiVectorDistanceMetric(BaseModel):
    """
    Optimal-transport distance between two sequences of vectors.

    Calling an instance builds a pairwise cost matrix with
    ``vector_distance``, gives every vector on a side the same weight
    (``1 / len(side)``) and returns whatever ``ot_method`` produces for
    those weights, that cost matrix and ``reg``.

    see https://pythonot.github.io/quickstart.html
    """
    # Solver, invoked positionally as
    # ot_method(query_weights, compared_weights, cost_matrix, reg).
    ot_method: Callable = Field(default=ot.sinkhorn2)
    # Forwarded unchanged as the solver's fourth positional argument.
    reg: float = Field(default=1e-2)
    # Called as vector_distance(X_query, X_compared) to obtain the cost matrix.
    vector_distance: Callable = cosine_distances

    def __call__(self, X_query, X_compared):
        """Return ``ot_method``'s result for ``X_query`` versus ``X_compared``.

        Both arguments are passed to ``len`` and to ``vector_distance``.
        """
        cost_matrix = self.vector_distance(X_query, X_compared)

        # Uniform mass: each side's weights are all equal and sum to one.
        n_query = len(X_query)
        query_mass = np.ones(n_query) / n_query
        n_compared = len(X_compared)
        compared_mass = np.ones(n_compared) / n_compared

        return self.ot_method(query_mass, compared_mass, cost_matrix, self.reg)

# Metric with all defaults (ot.sinkhorn2, reg=1e-2, cosine_distances).
sinkhorn_distance = MultiVectorDistanceMetric()

# Score the checked embeddings against the second reference set. The result is
# not stored or printed here.
sinkhorn_distance(checked_embeddings, reference_embeddings[1])
# Value shown after the call in the original listing (presumably the
# notebook cell output): 0.39404594056062947
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment