From 26dcce789356b7c992d896ffaf2c0bdf24e46d02 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 25 Apr 2023 22:13:22 -0700 Subject: [PATCH] Added Vector struct --- lib/pgvector.ex | 56 +++++++++++++++++++++++++++++++ lib/pgvector/extensions/vector.ex | 31 ++--------------- lib/pgvector/vector.ex | 11 ++++++ test/ecto_test.exs | 2 +- test/postgrex_test.exs | 4 +-- test/vector_test.exs | 18 ++++++++++ 6 files changed, 91 insertions(+), 31 deletions(-) create mode 100644 lib/pgvector.ex create mode 100644 lib/pgvector/vector.ex create mode 100644 test/vector_test.exs diff --git a/lib/pgvector.ex b/lib/pgvector.ex new file mode 100644 index 0000000..f3a5b49 --- /dev/null +++ b/lib/pgvector.ex @@ -0,0 +1,56 @@ +defmodule Pgvector do + def vector(list) when is_list(list) do + dim = list |> length() + bin = Enum.map(list, fn v -> <> end) |> :erlang.list_to_bitstring() + data = <> + %Pgvector.Vector{data: data} + end + + if Code.ensure_loaded?(Nx) do + def vector(t) when is_struct(t, Nx.Tensor) do + if Nx.rank(t) != 1 do + raise ArgumentError, "expected rank to be 1" + end + dim = t |> Nx.size() + bin = t |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_to_big() |> :erlang.list_to_bitstring() + data = <> + %Pgvector.Vector{data: data} + end + + defp f32_to_big(bin) do + if System.endianness() == :big do + bin + else + for <>, do: <> + end + end + end + + def from_binary(binary) when is_binary(binary) do + %Pgvector.Vector{data: binary} + end + + def to_binary(vector) when is_struct(vector, Pgvector.Vector) do + vector.data + end + + def to_list(vector) when is_struct(vector, Pgvector.Vector) do + <> = vector.data + for <>, do: v + end + + if Code.ensure_loaded?(Nx) do + def to_tensor(vector) when is_struct(vector, Pgvector.Vector) do + <> = vector.data + bin |> big_to_f32() |> :erlang.list_to_bitstring() |> Nx.from_binary(:f32) + end + + defp big_to_f32(bin) do + if System.endianness() == :big do + bin + else + for <>, do: <> + end + end + end +end diff --git a/lib/pgvector/extensions/vector.ex b/lib/pgvector/extensions/vector.ex index 288d992..06ad3bb 100644 --- a/lib/pgvector/extensions/vector.ex +++ b/lib/pgvector/extensions/vector.ex @@ -10,40 +10,15 @@ defmodule Pgvector.Extensions.Vector do def encode(_) do quote do vec -> - data = unquote(__MODULE__).encode_vector(vec) + data = vec |> Pgvector.vector() |> Pgvector.to_binary() [<> | data] end end def decode(_) do quote do - <<_len::int32(), dim::uint16, 0::uint16, bin::binary-size(dim)-unit(32)>> -> - for <>, do: v - end - end - - def encode_vector(list) when is_list(list) do - dim = list |> length() - bin = for v <- list, do: <> - [<> | bin] - end - - if Code.ensure_loaded?(Nx) do - def encode_vector(tensor) when is_struct(tensor, Nx.Tensor) do - if Nx.rank(tensor) != 1 do - raise ArgumentError, "expected rank to be 1" - end - dim = tensor |> Nx.size() - bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_to_big() - [<> | bin] - end - - defp f32_to_big(bin) do - if System.endianness() == :big do - bin - else - for <>, do: <> - end + <> -> + data |> Pgvector.from_binary() end end end diff --git a/lib/pgvector/vector.ex b/lib/pgvector/vector.ex new file mode 100644 index 0000000..e8499fb --- /dev/null +++ b/lib/pgvector/vector.ex @@ -0,0 +1,11 @@ +defmodule Pgvector.Vector do + defstruct [:data] +end + +defimpl Inspect, for: Pgvector.Vector do + import Inspect.Algebra + + def inspect(vec, opts) do + concat(["Vector(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) + end +end diff --git a/test/ecto_test.exs b/test/ecto_test.exs index 52b70ea..963d862 100644 --- a/test/ecto_test.exs +++ b/test/ecto_test.exs @@ -33,7 +33,7 @@ defmodule EctoTest do items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5) assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] - assert Enum.map(items, fn v -> v.embedding end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] + assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5) assert Enum.map(items, fn v -> v.id end) == [2, 3, 1] diff --git a/test/postgrex_test.exs b/test/postgrex_test.exs index e63e7ac..a6629b2 100644 --- a/test/postgrex_test.exs +++ b/test/postgrex_test.exs @@ -26,7 +26,7 @@ defmodule PostgrexTest do assert ["id", "embedding"] == result.columns assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2] - assert Enum.map(result.rows, fn v -> Enum.at(v, 1) end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] + assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] Postgrex.query!(pid, "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", []) end @@ -34,7 +34,7 @@ defmodule PostgrexTest do test "tensor", %{pid: pid} = _context do embedding = Nx.tensor([1.0, 2.0, 3.0]) result = Postgrex.query!(pid, "SELECT $1::vector", [embedding]) - assert result.rows == [[Nx.to_list(embedding)]] + assert Enum.map(result.rows, fn v -> Enum.at(v, 0) |> Pgvector.to_tensor() end) == [embedding] end test "tensor rank", %{pid: pid} = _context do diff --git a/test/vector_test.exs b/test/vector_test.exs new file mode 100644 index 0000000..5555559 --- /dev/null +++ b/test/vector_test.exs @@ -0,0 +1,18 @@ +defmodule VectorTest do + use ExUnit.Case + + test "list" do + list = [1.0, 2.0, 3.0] + assert list == (list |> Pgvector.vector() |> Pgvector.to_list()) + end + + test "tensor" do + tensor = Nx.tensor([1.0, 2.0, 3.0]) + assert tensor == (tensor |> Pgvector.vector() |> Pgvector.to_tensor()) + end + + test "inspect" do + vector = Pgvector.vector([1, 2, 3]) + assert "Vector([1.0, 2.0, 3.0])" == inspect(vector) + end +end