# Dependency setup for this Livebook: ortex (ONNX runtime bindings), tokenizers,
# nx/exla (pinned to matching git revisions, hence the overrides), bumblebee for
# cosine similarity, and Kino/VegaLite for interactive inputs and charting.
Mix.install([
{:ortex, github: "elixir-nx/ortex", ref: "9e384971d1904ba91e5bfa49594d742a1d06cb4c"},
{:tokenizers,
github: "elixir-nx/tokenizers", override: true, ref: "20295cfdf9b6342d723b405481791ec87afa203c"},
{:exla,
github: "elixir-nx/nx",
sparse: "exla",
override: true,
ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
{:nx,
github: "elixir-nx/nx",
sparse: "nx",
override: true,
ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
{:bumblebee, github: "elixir-nx/bumblebee", ref: "8ec547243a4a1a61e45b25780a994014dc986099"},
{:kino, "~> 0.10"},
{:vega_lite, "~> 0.1.7"},
{:kino_vega_lite, "~> 0.1.9"}
])
# Store tensors on the EXLA backend and compile defn functions with EXLA,
# using the CPU (:host) client.
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
alias VegaLite, as: Vl
# This section requires Python with venv support to be installed.
# See https://huggingface.co/docs/transformers/serialization?highlight=onnx.
# Working directory for the Python venv and the exported ONNX model.
# Path.join/2 guarantees a separator between the tmp dir and the folder name;
# the previous string concatenation produced e.g. "/tmplivebook_…" whenever
# System.tmp_dir!/0 returned a path without a trailing slash.
tmp_dir = Path.join(System.tmp_dir!(), "livebook_ortex_sentence_embeddings")

# mkdir_p!/1 is idempotent (succeeds if the directory already exists) and
# raises on real failures, unlike File.mkdir/1 whose result was ignored.
File.mkdir_p!(tmp_dir)

# Environment that activates the venv for the shell commands below: pip3 and
# optimum-cli must resolve to the venv's bin/ rather than system installs.
venv_env = [
  {"VIRTUAL_ENV", Path.join([tmp_dir, ".venv"])},
  {"PATH", "#{Path.join([tmp_dir, ".venv", "bin"])}:#{System.get_env("PATH")}"}
]

# Create the Python virtual environment.
System.shell("python3 -m venv .venv", cd: tmp_dir, into: IO.binstream())

# Install the Hugging Face optimum exporter into the venv.
System.shell(
  "pip3 install optimum[exporters]",
  cd: tmp_dir,
  env: venv_env,
  into: IO.binstream()
)

# Export the sentence-transformers model to ONNX under <tmp_dir>/minilm/.
System.shell(
  "optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 minilm/",
  cd: tmp_dir,
  env: venv_env,
  into: IO.binstream()
)

# Load the exported ONNX model and the matching tokenizer.
model = Ortex.load(Path.join([tmp_dir, "minilm", "model.onnx"]))
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
defmodule SentenceTransformerOnnx do
  import Nx.Defn

  # Attention-mask-weighted mean over the token axis: averages the model's
  # token embeddings so that padding tokens contribute nothing.
  # model_output: last hidden state; attention_mask: 1 for real tokens, 0 for padding.
  defn mean_pooling(model_output, attention_mask) do
    # Trailing axis so the mask broadcasts across the embedding dimension.
    input_mask_expanded = Nx.new_axis(attention_mask, -1)

    model_output
    |> Nx.multiply(input_mask_expanded)
    |> Nx.sum(axes: [1])
    # Divide by the number of attended (non-padding) tokens per sequence.
    |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
  end

  # Builds an Nx.Serving that tokenizes raw strings, runs the ONNX model via
  # Ortex, and mean-pools the token embeddings into one vector per sentence.
  def serving(model, tokenizer) do
    Nx.Serving.new(Ortex.Serving, model)
    |> Nx.Serving.client_preprocessing(fn inputs ->
      {:ok, encodings} = Tokenizers.Tokenizer.encode_batch(tokenizer, inputs)

      # get the maximum sequence length from the input by looking at the attention mask
      max_length =
        encodings
        |> Enum.map(&Tokenizers.Encoding.get_attention_mask/1)
        |> Enum.map(fn tensor -> Enum.sum(tensor) end)
        |> Enum.max(fn -> nil end)

      # Trim every encoding down to the longest real sequence so the stacked
      # batch carries no unnecessary padding (max_length is nil for empty input).
      encodings =
        if max_length do
          for e <- encodings, do: Tokenizers.Encoding.truncate(e, max_length)
        else
          encodings
        end

      input_ids = for i <- encodings, do: Tokenizers.Encoding.get_ids(i)
      input_mask = for i <- encodings, do: Tokenizers.Encoding.get_attention_mask(i)
      token_type_ids = for i <- encodings, do: Tokenizers.Encoding.get_type_ids(i)

      # One {ids, mask, type_ids} tuple per input, stacked into a single batch.
      inputs =
        Enum.zip_with([input_ids, input_mask, token_type_ids], fn [a, b, c] ->
          {Nx.tensor(a), Nx.tensor(b), Nx.tensor(c)}
        end)
        |> Nx.Batch.stack()

      # The attention mask travels as client metadata so postprocessing can pool.
      # NOTE(review): this is the mask for the whole request; if the serving
      # splits the batch (batch_size < request size) it may not line up with
      # each sub-batch — confirm against Nx.Serving's batching behavior.
      {inputs, %{attention_mask: Nx.tensor(input_mask)}}
    end)
    |> Nx.Serving.client_postprocessing(fn {{output}, _meta}, client_info ->
      # Pool per-token embeddings into a single embedding per sentence.
      mean_pooling(output, client_info.attention_mask)
    end)
  end
end
# Build the embedding serving from the ONNX model and tokenizer loaded above.
serving = SentenceTransformerOnnx.serving(model, tokenizer)

# Livebook inputs: one source sentence and a multi-line list of candidates.
input = Kino.Input.text("Source sentence:") |> Kino.render()
comparison = Kino.Input.textarea("Sentences to compare to:") |> Kino.render()
input_text = Kino.Input.read(input)

# Stop evaluating this cell early when no source sentence was provided.
if byte_size(input_text) == 0 do
  Kino.interrupt!(:normal, "Please enter source sentence")
end

# One comparison sentence per non-empty line.
comparison_texts =
  Kino.Input.read(comparison)
  |> String.split("\n", trim: true)

if comparison_texts == [] do
  Kino.interrupt!(:normal, "Please enter comparison sentences (one per line)")
end

# Embed the source sentence and all comparison sentences.
input = Nx.Serving.run(serving, [input_text])
comparison =
  Nx.Serving.run(serving, comparison_texts)

# Cosine similarity between the source embedding and each candidate embedding.
sim = Bumblebee.Utils.Nx.cosine_similarity(input, comparison)

# Rank candidates from most to least similar to the source sentence.
for {v, i} <- Enum.with_index(comparison_texts) do
  {v, Nx.to_number(sim[0][i])}
end
|> Enum.sort_by(&elem(&1, 1), :desc)
defmodule ConcurrentBench do
  @moduledoc """
  Naive throughput benchmark: runs a function from several concurrent tasks
  for a fixed wall-clock window and reports total runs and runs per second.
  """

  @doc """
  Runs `fun` repeatedly from `concurrency` tasks for roughly `timeout` ms.

  Returns `%{runs: total_completed_invocations, ips: invocations_per_second}`.
  """
  def run(fun, concurrency \\ System.schedulers_online(), timeout \\ 10_000) do
    # Shared atomic counter tracking how many invocations completed.
    counter = :counters.new(1, [:write_concurrency])

    # :timer.tc/1 reports elapsed time in microseconds.
    {elapsed_us, _} =
      :timer.tc(fn ->
        tasks =
          for _ <- 1..concurrency do
            Task.async(fn ->
              # Each task loops forever; it is killed once the window elapses.
              Stream.repeatedly(fn ->
                fun.()
                # Only count invocations that completed without raising.
                :counters.add(counter, 1, 1)
              end)
              |> Stream.run()
            end)
          end

        # Wait out the window, then brutally kill whatever is still running.
        # Enum.each: this loop is purely for its shutdown side effect.
        tasks
        |> Task.yield_many(timeout)
        |> Enum.each(fn {task, result} ->
          result || Task.shutdown(task, :brutal_kill)
        end)
      end)

    runs = :counters.get(counter, 1)
    # Convert microseconds to seconds for the iterations-per-second figure.
    %{runs: runs, ips: runs / (elapsed_us / 1_000_000)}
  end
end
# Sample paragraph used to generate benchmark inputs of increasing length.
text =
  "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr"

splitted = String.split(text, " ", trim: true)

# Every prefix of the text ("Lorem", "Lorem ipsum", ...), built in one O(n)
# scan instead of re-taking and re-joining the word list for each length
# (the previous 1..length/Enum.take loop was O(n^2)).
texts = Enum.scan(splitted, fn word, prefix -> prefix <> " " <> word end)
# Token count for each prefix (sum of the attention mask = number of real
# tokens); used as the x-axis of the benchmark chart below.
sequence_lengths =
  Enum.map(texts, fn text ->
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)
    Tokenizers.Encoding.get_attention_mask(encoding) |> Enum.sum()
  end)
defmodule BenchTest do
  @moduledoc """
  Starts a dedicated `Nx.Serving` process with the given batching options,
  benchmarks `Nx.Serving.batched_run/2` against it, and tears it down again.
  """

  @doc """
  Benchmarks `serving` for one `text` at the given `batch_size` and
  `batch_timeout`, hit from `concurrency` concurrent callers for `timeout` ms.

  Returns the `%{runs: _, ips: _}` map from `ConcurrentBench.run/3`.
  """
  def run(serving, text, batch_size, batch_timeout, concurrency, timeout \\ 10_000) do
    {:ok, pid} =
      Kino.start_child(
        {Nx.Serving,
         serving: serving, name: MyServing, batch_size: batch_size, batch_timeout: batch_timeout}
      )

    # Call ConcurrentBench directly. The original compiled a fresh throwaway
    # module (Module.concat + defmodule) on every invocation just to wrap this
    # call, leaking an atom and churning the code server each run for no benefit.
    #
    # try/after guarantees the serving process is terminated even if the
    # benchmark raises, so repeated runs never fail on the already-registered
    # MyServing name.
    try do
      ConcurrentBench.run(
        fn -> Nx.Serving.batched_run(MyServing, [text]) end,
        concurrency,
        timeout
      )
      |> IO.inspect(
        label: "batch: #{batch_size}; concurrency: #{concurrency}, timeout: #{batch_timeout}"
      )
    after
      Kino.terminate_child(pid)
    end
  end
end
# Live line chart: throughput (ips) on the y-axis versus tokenized input
# length on the x-axis; points are pushed in as each benchmark completes.
chart =
  Vl.new(width: 1280, height: 720)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "sequence_length", type: :quantitative)
  |> Vl.encode_field(:y, "ips", type: :quantitative)
  |> Kino.VegaLite.new()

# Benchmark each text prefix with fixed batching/concurrency settings and
# stream the measured throughput into the chart.
for {text, sequence_length} <- Enum.zip(texts, sequence_lengths) do
  batch_size = 64
  batch_timeout = 50
  concurrency = 64
  %{ips: ips} = BenchTest.run(serving, text, batch_size, batch_timeout, concurrency, 5_000)
  IO.inspect("sequence_length: #{sequence_length}, ips: #{ips}")
  Kino.VegaLite.push(chart, %{ips: ips, sequence_length: sequence_length})
end