# Install dependencies pinned to exact git refs so this notebook stays
# reproducible: ortex (ONNX runtime bindings), tokenizers (HF tokenizers),
# nx/exla from the same monorepo commit (sparse checkouts, overriding the
# transitive versions), bumblebee (used below for cosine similarity), and
# kino for the Livebook UI widgets.
Mix.install([
{:ortex, github: "elixir-nx/ortex", ref: "9e384971d1904ba91e5bfa49594d742a1d06cb4c"},
{:tokenizers,
github: "elixir-nx/tokenizers", override: true, ref: "20295cfdf9b6342d723b405481791ec87afa203c"},
{:exla,
github: "elixir-nx/nx",
sparse: "exla",
override: true,
ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
{:nx,
github: "elixir-nx/nx",
sparse: "nx",
override: true,
ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
{:bumblebee, github: "elixir-nx/bumblebee", ref: "8ec547243a4a1a61e45b25780a994014dc986099"},
{:kino, "~> 0.10"}
])
# Route all Nx tensor operations through the EXLA backend, and compile
# defn functions with EXLA on the host CPU client.
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
# NOTE(review): Vl is never used in this visible chunk — presumably kept
# for later plotting cells; confirm before removing.
alias VegaLite, as: Vl
This section requires Python 3 with venv support to be installed on the host machine.
For details on exporting Transformers models to ONNX, see https://huggingface.co/docs/transformers/serialization?highlight=onnx.
# Scratch directory for the Python virtualenv and the exported ONNX model.
# Use Path.join/2: the original `System.tmp_dir!() <> "…"` concatenated
# without a separator (e.g. "/tmp" <> "livebook_ortex_mpnet" ->
# "/tmplivebook_ortex_mpnet"), producing a path outside the tmp dir.
tmp_dir = Path.join(System.tmp_dir!(), "livebook_ortex_mpnet")

# mkdir_p!/1 is idempotent across re-runs and raises loudly on a real
# failure instead of silently discarding an {:error, _} tuple.
File.mkdir_p!(tmp_dir)

# Environment that activates the virtualenv for the shell commands below
# (defined once instead of being repeated per command).
venv_env = [
  {"VIRTUAL_ENV", Path.join([tmp_dir, ".venv"])},
  {"PATH", "#{Path.join([tmp_dir, ".venv", "bin"])}:#{System.get_env("PATH")}"}
]

# Create the virtualenv, install the ONNX exporter, then export the
# sentence-transformers model to mpnet/model.onnx. Commands are fixed
# strings (no interpolated user input), so System.shell/2 is safe here.
System.shell("python3 -m venv .venv", cd: tmp_dir, into: IO.binstream())

System.shell(
  "pip3 install optimum[exporters]",
  cd: tmp_dir,
  env: venv_env,
  into: IO.binstream()
)

System.shell(
  "optimum-cli export onnx --model sentence-transformers/all-mpnet-base-v2 mpnet/",
  cd: tmp_dir,
  env: venv_env,
  into: IO.binstream()
)
# Load the exported ONNX model into Ortex and fetch the matching
# tokenizer from the Hugging Face hub (must be the same checkpoint the
# model was exported from, or token ids will not line up).
model = Ortex.load(Path.join([tmp_dir, "mpnet", "model.onnx"]))
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
defmodule SentenceTransformerOnnx do
  @moduledoc """
  Sentence embeddings from an ONNX export of
  `sentence-transformers/all-mpnet-base-v2`, served through Ortex.

  Mirrors the reference pipeline: run the transformer, then mean-pool the
  token embeddings weighted by the attention mask.
  """

  import Nx.Defn

  @doc """
  Mean-pools token-level embeddings, ignoring padding positions.

  `model_output` is the transformer's token-level output and
  `attention_mask` marks real tokens with 1 and padding with 0.
  """
  defn mean_pooling(model_output, attention_mask) do
    # Expand the mask so it broadcasts across the hidden dimension.
    input_mask_expanded = Nx.new_axis(attention_mask, -1)

    masked_sum =
      model_output
      |> Nx.multiply(input_mask_expanded)
      |> Nx.sum(axes: [1])

    # Clamp the denominator so an all-padding row cannot divide by zero
    # (same guard as the reference sentence-transformers pooling, which
    # clamps the mask sum to a minimum of 1e-9).
    token_counts = Nx.max(Nx.sum(input_mask_expanded, axes: [1]), 1.0e-9)

    Nx.divide(masked_sum, token_counts)
  end

  @doc """
  Builds an `Nx.Serving` that tokenizes raw strings, runs the ONNX model,
  and post-processes the output into mean-pooled sentence embeddings.
  """
  def serving(model, tokenizer) do
    Nx.Serving.new(Ortex.Serving, model)
    |> Nx.Serving.client_preprocessing(fn inputs ->
      {:ok, encodings} = Tokenizers.Tokenizer.encode_batch(tokenizer, inputs)
      input_ids = for i <- encodings, do: Tokenizers.Encoding.get_ids(i)
      input_mask = for i <- encodings, do: Tokenizers.Encoding.get_attention_mask(i)

      # NOTE(review): sequences are stacked as-is; this assumes the
      # tokenizer produces equal-length encodings for the batch — confirm
      # padding behavior for ragged inputs.
      inputs =
        Enum.zip_with(input_ids, input_mask, fn a, b ->
          {Nx.tensor(a), Nx.tensor(b)}
        end)
        |> Nx.Batch.stack()

      # The attention mask rides along as client metadata so pooling can
      # happen after inference.
      {inputs, %{attention_mask: Nx.tensor(input_mask)}}
    end)
    |> Nx.Serving.client_postprocessing(fn {{output}, _meta}, client_info ->
      mean_pooling(output, client_info.attention_mask)
    end)
  end
end
A playground similar to the inference widget on the Hugging Face model page: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
# Build the serving and render the playground widgets: one text input for
# the source sentence and a textarea with one comparison sentence per line.
serving = SentenceTransformerOnnx.serving(model, tokenizer)
input = Kino.Input.text("Source sentence:", default: "That is a happy person") |> Kino.render()
comparison =
Kino.Input.textarea("Sentences to compare to:",
default: """
That is a happy dog
That is a very happy person
Today is a sunny day
"""
)
|> Kino.render()
# Read both widgets and stop the cell early (without raising) when either
# one is empty.
input_text = Kino.Input.read(input)

if input_text == "" do
  Kino.interrupt!(:normal, "Please enter source sentence")
end

comparison_texts =
  comparison
  |> Kino.Input.read()
  |> String.split("\n", trim: true)

if Enum.empty?(comparison_texts) do
  Kino.interrupt!(:normal, "Please enter comparison sentences (one per line)")
end
# Embed the source sentence and the comparison sentences, then score every
# (source, comparison) pair with cosine similarity.
input = Nx.Serving.run(serving, [input_text])
comparison = Nx.Serving.run(serving, comparison_texts)

# NOTE(review): Bumblebee.Utils.Nx is an internal module — its API may
# change between Bumblebee versions.
sim = Bumblebee.Utils.Nx.cosine_similarity(input, comparison)

# Pair each comparison sentence with its score; the resulting list is the
# cell's value and is rendered by Livebook.
comparison_texts
|> Enum.with_index()
|> Enum.map(fn {sentence, idx} -> {sentence, Nx.to_number(sim[0][idx])} end)