From 2f5071464471257f98be98f558a278d8bb2eb2bc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:50:32 -0300 Subject: [PATCH 01/14] fix: make exla build resilient to stale upgrades (#1548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim Co-authored-by: Jonatan Kłosko --- exla/Makefile | 4 +- exla/README.md | 20 +++++++ exla/lib/exla/nif.ex | 15 +++++- exla/mix.exs | 52 ++++++++++++++++--- exla/test/exla/device_memory_sharing_test.exs | 9 ++-- 5 files changed, 86 insertions(+), 14 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index a371447d101..695c7f9409d 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -8,8 +8,8 @@ XLA_EXTENSION_LIB = $(XLA_EXTENSION_DIR)/lib XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include # Cache configuration -EXLA_CACHE_SO = cache/libexla.so -EXLA_CACHE_OBJ_DIR = cache/objs +EXLA_CACHE_SO = cache/$(EXLA_VERSION)/libexla.so +EXLA_CACHE_OBJ_DIR = cache/$(EXLA_VERSION)/objs # Private configuration EXLA_DIR = c_src/exla diff --git a/exla/README.md b/exla/README.md index 3091555796a..2ffe144ff72 100644 --- a/exla/README.md +++ b/exla/README.md @@ -48,6 +48,26 @@ EXLA relies on the [XLA](https://github.com/elixir-nx/xla) package to provide th For cross-compilation, you need to [set your `XLA_TARGET_PLATFORM` variable](https://github.com/elixir-nx/xla#xla_target_platform) to the correct target platform value (i.e. `aarch64-linux-gnu` for the Raspberry Pi 4). +## Troubleshooting + +EXLA uses NIFs (C-interface code called from Elixir) for part of its functionality. +If for any reason these fail to compile or load, troubleshooting can be tricky. + +We recommend following the steps below: + + 1. If the error appeared after upgrading EXLA, ensure that you have the proper dependency versions given by [XLA](https://github.com/elixir-nx/xla). Afterwards, compile with `mix compile` after setting `EXLA_FORCE_REBUILD` to clean up cached files: + * `EXLA_FORCE_REBUILD=partial`: Removes the only the libexla.so caches (both local and global ones). + * `EXLA_FORCE_REBUILD=true`: Removes the libexla.so caches but also removes the intermediate `.o` compilation artifacts retained from previous builds. + + Additional notes on compilation: + * Besides the XLA dependency versions, ensuring `gcc` (or your compiler of choice), `libc` and `make` are compatible is also important. + * Remember to save the compilation logs from this step for further debugging. + * It is a good idea to save the `cache//libexla.so` file so that the team can inspect its contents if needed. + 2. If the error persists, look for the `** (RuntimeError) Failed to load NIF library.` exception on application start-up. + This exception should provide more information on what's the issue when loading the NIF. Share these logs in an issue on GitHub + so that the Nx team can investigate further. + + ## Contributing ### Building locally diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index be0567cc0aa..023a0bcbd21 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -4,7 +4,20 @@ defmodule EXLA.NIF do def __on_load__ do path = :filename.join(:code.priv_dir(:exla), ~c"libexla") - :erlang.load_nif(path, 0) + + case :erlang.load_nif(path, 0) do + :ok -> + :ok + + {:error, {reason, text}} -> + raise """ + Failed to load NIF library. + Follow the steps in the :exla README Troubleshooting section for more information. 
+ + #{inspect(reason)} + #{text} + """ + end end def mlir_new_thread_pool(_concurrency), do: :erlang.nif_error(:undef) diff --git a/exla/mix.exs b/exla/mix.exs index 4036616379f..94ad24fa1dc 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -35,7 +35,8 @@ defmodule EXLA.MixProject do %{ "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", - "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv + "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, + "EXLA_VERSION" => "#{@version}" } end, make_args: make_args @@ -133,7 +134,38 @@ defmodule EXLA.MixProject do {:ok, []} end - defp cached_make(_) do + defp cached_make(args) do + force_rebuild_mode = + case System.get_env("EXLA_FORCE_REBUILD", "") do + "" -> + :none + + "0" -> + :none + + "partial" -> + :partial + + "true" -> + :full + + "1" -> + :full + + value -> + Mix.raise( + "invalid value for EXLA_FORCE_REBUILD: '#{value}'. Expected one of: partial, true" + ) + end + + File.mkdir_p!("cache/#{@version}") + + # remove only in full mode + if force_rebuild_mode in [:partial, :full] do + Mix.shell().info("Removing cached .o files in cache/#{@version}/objs") + File.rm_rf!("cache/#{@version}/objs") + end + contents = for path <- Path.wildcard("c_src/**/*"), {:ok, contents} <- [File.read(path)], @@ -148,19 +180,27 @@ defmodule EXLA.MixProject do "elixir-#{System.version()}-erts-#{:erlang.system_info(:version)}-xla-#{Application.spec(:xla, :vsn)}-exla-#{@version}-#{md5}" cached_so = Path.join([xla_cache_dir(), "exla", cache_key, "libexla.so"]) - cached? = File.exists?(cached_so) + cached? = File.exists?(cached_so) and force_rebuild_mode == :none + + if force_rebuild_mode in [:partial, :full] do + Mix.shell().info("Removing cached libexla.so file in cache/#{@version}/libexla.so") + File.rm_rf!("cache/#{@version}/libexla.so") + + Mix.shell().info("Removing libexla.so cache at #{cached_so}") + File.rm!(cached_so) + end if cached? do Mix.shell().info("Using libexla.so from #{cached_so}") - File.cp!(cached_so, "cache/libexla.so") + File.cp!(cached_so, "cache/#{@version}/libexla.so") end - result = Mix.Tasks.Compile.ElixirMake.run([]) + result = Mix.Tasks.Compile.ElixirMake.run(args) if not cached? 
and match?({:ok, _}, result) do Mix.shell().info("Caching libexla.so at #{cached_so}") File.mkdir_p!(Path.dirname(cached_so)) - File.cp!("cache/libexla.so", cached_so) + File.cp!("cache/#{@version}/libexla.so", cached_so) end result diff --git a/exla/test/exla/device_memory_sharing_test.exs b/exla/test/exla/device_memory_sharing_test.exs index e986ea1ff83..09e54a42eb7 100644 --- a/exla/test/exla/device_memory_sharing_test.exs +++ b/exla/test/exla/device_memory_sharing_test.exs @@ -27,14 +27,13 @@ defmodule EXLA.DeviceMemorySharingTest do end @tag :cuda_required - test "ipc handles don't crash the runtime when :local mode is selected" do - assert {:error, ~c"Invalid pointer size for selected mode."} == + test "invalid ipc handles don't crash the runtime" do + assert {:error, ~c"Unable to get pointer for IPC handle."} == Nx.from_pointer( {EXLA.Backend, client: :cuda}, - Enum.to_list(0..63), + %Nx.Pointer{handle: "#{System.unique_integer()}", kind: :ipc, data_size: 4}, {:f, 32}, - {1}, - mode: :local + {1} ) end end From 81a7bb7a1f396725fe2bef0d85d5ec634b634f13 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:26:02 -0300 Subject: [PATCH 02/14] fix: least_squares implementation (#1550) --- exla/test/exla/nx_linalg_doctest_test.exs | 2 +- nx/lib/nx/lin_alg.ex | 24 +++++++++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 6df3aeec102..10c2cbce059 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do invert: 1, matrix_power: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2] + @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 0a353f43205..450a48cae42 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do @doc """ Return the least-squares solution to a linear matrix equation Ax = b. + ## Options + + * `:eps` - Rounding error threshold used to assume values as 0. 
Defaults to `1.0e-15` + ## Examples iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2])) #Nx.Tensor< f32[2] - [1.0000004768371582, -2.665601925855299e-7] + [0.9977624416351318, 0.0011188983917236328] > iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1])) @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do ** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3} """ @doc from_backend: false - defn least_squares(a, b) do + defn least_squares(a, b, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-15) + %T{type: a_type, shape: a_shape} = Nx.to_tensor(a) a_size = Nx.rank(a_shape) %T{type: b_type, shape: b_shape} = Nx.to_tensor(b) @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do ) end - case a_shape do - {m, n} when m == n -> - Nx.LinAlg.solve(a, b) - - {m, n} when m != n -> - Nx.LinAlg.pinv(a, eps: 1.0e-15) - |> Nx.dot(b) - - _ -> - nil - end + a + |> Nx.LinAlg.pinv(eps: opts[:eps]) + |> Nx.dot(b) end defp apply_vectorized(tensor, fun) when is_function(fun, 1) do From d64ba465175e90e85f6139e98c10ad64335c0dd4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:19:58 -0300 Subject: [PATCH 03/14] feat(exla): add LU custom_call (#1549) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/c_src/exla/custom_calls.cc | 10 ++- exla/c_src/exla/custom_calls/lu.h | 95 +++++++++++++++++++++++ exla/c_src/exla/custom_calls/lu_bf16.cc | 6 ++ exla/c_src/exla/custom_calls/lu_f16.cc | 6 ++ exla/c_src/exla/custom_calls/lu_f32.cc | 5 ++ exla/c_src/exla/custom_calls/lu_f64.cc | 5 ++ exla/c_src/exla/custom_calls/qr.h | 8 +- exla/lib/exla/defn.ex | 41 +++++++++- exla/lib/exla/mlir/value.ex | 75 ++++++++++++++++++ exla/mix.exs | 2 +- exla/test/exla/nx_linalg_doctest_test.exs | 22 ++++-- 11 files changed, 258 insertions(+), 17 deletions(-) create mode 100644 exla/c_src/exla/custom_calls/lu.h create mode 100644 exla/c_src/exla/custom_calls/lu_bf16.cc create mode 100644 exla/c_src/exla/custom_calls/lu_f16.cc create mode 100644 exla/c_src/exla/custom_calls/lu_f32.cc create mode 100644 exla/c_src/exla/custom_calls/lu_f64.cc diff --git a/exla/c_src/exla/custom_calls.cc b/exla/c_src/exla/custom_calls.cc index d5beee79581..8acd67cab6e 100644 --- a/exla/c_src/exla/custom_calls.cc +++ b/exla/c_src/exla/custom_calls.cc @@ -4,6 +4,10 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]); void qr_cpu_custom_call_f64(void *out[], const void *in[]); void qr_cpu_custom_call_f16(void *out[], const void *in[]); void qr_cpu_custom_call_bf16(void *out[], const void *in[]); +void lu_cpu_custom_call_f32(void *out[], const void *in[]); +void lu_cpu_custom_call_f64(void *out[], const void *in[]); +void lu_cpu_custom_call_f16(void *out[], const void *in[]); +void lu_cpu_custom_call_bf16(void *out[], const void *in[]); void eigh_cpu_custom_call_f32(void *out[], const void *in[]); void eigh_cpu_custom_call_f64(void *out[], const void *in[]); @@ -12,4 +16,8 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_cu XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16); XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16); 
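// The lu_cpu_custom_call_* targets registered below are the new LU entry points added by
// this patch. Each one lives in its own custom_calls/lu_*.cc file, which includes the shared
// Eigen-based template in custom_calls/lu.h instantiated for that element type.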
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32); \ No newline at end of file +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16); \ No newline at end of file diff --git a/exla/c_src/exla/custom_calls/lu.h b/exla/c_src/exla/custom_calls/lu.h new file mode 100644 index 00000000000..1c72565d4b5 --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu.h @@ -0,0 +1,95 @@ +#pragma once + +#include "Eigen/LU"; + +template +void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) { + typedef Eigen::Matrix RowMajorMatrix; + + Eigen::Map input(in, n, n); + Eigen::PartialPivLU lu = input.partialPivLu(); + + // Get the permutation matrix P and convert to indices + Eigen::PermutationMatrix P = lu.permutationP(); + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + p_out[i * n + j] = static_cast(P.indices()[i] == j ? 1 : 0); + } + } + + // Get L and U matrices + RowMajorMatrix L = lu.matrixLU().template triangularView(); + RowMajorMatrix U = lu.matrixLU().template triangularView(); + + // Copy L matrix + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + + if (j < i) { + l_out[i * n + j] = static_cast(L(i, j)); + } else if (j == i) { + l_out[i * n + j] = static_cast(1.0); + } else { + l_out[i * n + j] = static_cast(0.0); + } + } + } + + // Copy U matrix + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + if (j >= i) { + u_out[i * n + j] = static_cast(U(i, j)); + } else { + u_out[i * n + j] = static_cast(0.0); + } + } + } +} + +template +void lu_cpu_custom_call(void *out[], const void *in[]) { + DataType *operand = (DataType *)in[0]; + + uint64_t *dim_sizes = (uint64_t *)in[1]; + uint64_t num_operand_dims = dim_sizes[0]; + uint64_t num_p_dims = dim_sizes[1]; + uint64_t num_l_dims = dim_sizes[2]; + uint64_t num_u_dims = dim_sizes[3]; + + uint64_t *operand_dims_ptr = (uint64_t *)in[2]; + std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); + + uint64_t *p_dims_ptr = (uint64_t *)in[3]; + std::vector p_dims(p_dims_ptr, p_dims_ptr + num_p_dims); + + uint64_t *l_dims_ptr = (uint64_t *)in[4]; + std::vector l_dims(l_dims_ptr, l_dims_ptr + num_l_dims); + + uint64_t *u_dims_ptr = (uint64_t *)in[5]; + std::vector u_dims(u_dims_ptr, u_dims_ptr + num_u_dims); + + uint64_t n = l_dims[l_dims.size() - 1]; + + auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); + + uint64_t batch_items = 1; + for (uint64_t i = 0; i < leading_dimensions.size(); i++) { + batch_items *= leading_dimensions[i]; + } + + uint8_t *p = (uint8_t *)out[0]; + DataType *l = (DataType *)out[1]; + DataType *u = (DataType *)out[2]; + + uint64_t stride = n * n; + + for (uint64_t i = 0; i < batch_items; i++) { + single_matrix_lu_cpu_custom_call( + p + i * stride, + l + i * stride, + u + i * stride, + operand + i * stride, + n); + } +} \ No newline at end of file diff --git 
a/exla/c_src/exla/custom_calls/lu_bf16.cc b/exla/c_src/exla/custom_calls/lu_bf16.cc new file mode 100644 index 00000000000..806f886b4cc --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_bf16.cc @@ -0,0 +1,6 @@ +#include "lu.h" +#include "../exla_types.h" + +void lu_cpu_custom_call_bf16(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f16.cc b/exla/c_src/exla/custom_calls/lu_f16.cc new file mode 100644 index 00000000000..81f6724e6e4 --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f16.cc @@ -0,0 +1,6 @@ +#include "lu.h" +#include "../exla_types.h" + +void lu_cpu_custom_call_f16(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f32.cc b/exla/c_src/exla/custom_calls/lu_f32.cc new file mode 100644 index 00000000000..c506caab72f --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f32.cc @@ -0,0 +1,5 @@ +#include "lu.h" + +void lu_cpu_custom_call_f32(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f64.cc b/exla/c_src/exla/custom_calls/lu_f64.cc new file mode 100644 index 00000000000..aed6ed2dabb --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f64.cc @@ -0,0 +1,5 @@ +#include "lu.h" + +void lu_cpu_custom_call_f64(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/qr.h b/exla/c_src/exla/custom_calls/qr.h index 3615353ddf0..85e881447c5 100644 --- a/exla/c_src/exla/custom_calls/qr.h +++ b/exla/c_src/exla/custom_calls/qr.h @@ -73,15 +73,15 @@ void qr_cpu_custom_call(void *out[], const void *in[]) { DataType *q = (DataType *)out[0]; DataType *r = (DataType *)out[1]; - uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType); - uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType); - uint64_t inner_stride = m * n * sizeof(DataType); + uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2]; + uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2]; + uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_qr_cpu_custom_call( (DataType *)out[0] + i * q_stride, (DataType *)out[1] + i * r_stride, - operand + i * inner_stride * sizeof(DataType), + operand + i * inner_stride, m, k, n, complete); } } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 2f73b255623..a37492e32bb 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -737,6 +737,43 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator( + :lu, + %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}}, + state, + cache + ) do + %{type: {p_type_kind, _}} = p_expr + %{type: {out_type_kind, _}} = l_expr + + if state.client.platform != :host do + raise ArgumentError, "XLA does not currently support the LU operation on non-host devices" + end + + if p_type_kind == :c or out_type_kind == :c do + raise ArgumentError, "XLA does not currently support the LU operation for complex inputs" + end + + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + tensor = + if op_type(tensor) != u_expr.type do + to_type(tensor, u_expr.type) + else + tensor + end + + {p, l, u} = + Value.lu( + tensor, + expr_to_typespec(p_expr), + expr_to_typespec(l_expr), + expr_to_typespec(u_expr) + ) + + {[p, l, u], cache} + end + defp cached_recur_operator(:attach_token, %T{data: %Expr{args: 
[token, expr]}}, state, cache) do {op, cache} = recur_operator(expr, state, cache) {_, cache} = recur_operator(token, state, cache) @@ -965,10 +1002,6 @@ defmodule EXLA.Defn do end end - defp to_operator(:lu, [{_, _, _}, _tensor, _opts], _ans, _state) do - raise ArgumentError, "XLA does not currently support the LU operation" - end - ## to_operator element-wise defp to_operator(:negate, [%Value{} = op], ans, _state), diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 5dfd72ca237..e38d09fc0ba 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -815,6 +815,81 @@ defmodule EXLA.MLIR.Value do {q, r} end + def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do + %{type: op_type, shape: op_shape} = get_typespec(value) + %{type: _p_type, shape: p_shape} = p_typespec + %{type: l_type, shape: l_shape} = l_typespec + %{type: u_type, shape: u_shape} = u_typespec + + dim_sizes = [ + tuple_size(op_shape), + tuple_size(p_shape), + tuple_size(l_shape), + tuple_size(u_shape) + ] + + operand_dims = Tuple.to_list(op_shape) + p_dims = Tuple.to_list(p_shape) + l_dims = Tuple.to_list(l_shape) + u_dims = Tuple.to_list(u_shape) + + dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)})) + operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)})) + p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)})) + l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)})) + u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)})) + operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims] + + # Force P to always b u8 to avoid requiring too many template instances during custom_call registration + p_result_type = type_tensor({:u, 8}, p_shape) + l_result_type = type_tensor(l_type, l_shape) + u_result_type = type_tensor(u_type, u_shape) + result_types = [type_tuple([p_result_type, l_result_type, u_result_type])] + + call_target_name = + case op_type do + {:f, 32} -> + "lu_cpu_custom_call_f32" + + {:f, 64} -> + "lu_cpu_custom_call_f64" + + {:f, 16} -> + "lu_cpu_custom_call_f16" + + {:bf, 16} -> + "lu_cpu_custom_call_bf16" + + type -> + # Due to matching on EXLA.Defn, we are sure that the device here is always :host + raise "LU decomposition not supported on :host device for type #{inspect(type)}" + end + + attributes = [ + call_target_name: attr_string(call_target_name), + backend_config: attr_string("Host") + ] + + result = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() + + # This is not the best approach, but the alternative would require many more template instances + u8_typespec = Typespec.to_type(p_typespec, {:u, 8}) + p = get_tuple_element(result, 0, u8_typespec) + + p = + if u8_typespec != p_typespec do + convert(p, p_typespec) + else + p + end + + l = get_tuple_element(result, 1, l_typespec) + u = get_tuple_element(result, 2, u_typespec) + + {p, l, u} + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] diff --git a/exla/mix.exs b/exla/mix.exs index 94ad24fa1dc..0a09e463de6 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -187,7 +187,7 @@ defmodule EXLA.MixProject do File.rm_rf!("cache/#{@version}/libexla.so") Mix.shell().info("Removing libexla.so cache at #{cached_so}") - File.rm!(cached_so) + File.rm_rf!(cached_so) end if cached? 
do diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 10c2cbce059..09d60ba8f68 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -1,16 +1,24 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do use EXLA.Case, async: true - @invalid_type_error_doctests [svd: 2, pinv: 2, matrix_rank: 2] + @invalid_type_error_doctests [ + svd: 2, + pinv: 2 + ] + @function_clause_error_doctests [ - norm: 2, - lu: 2, - solve: 2, + solve: 2 + ] + + @rounding_error_doctests [ + triangular_solve: 3, + eigh: 2, + cholesky: 1, + least_squares: 3, determinant: 1, - invert: 1, - matrix_power: 2 + matrix_power: 2, + lu: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ From d21aca5dc1ee97ba1d0f378bc63a2600d4318b49 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 30 Oct 2024 02:24:53 -0300 Subject: [PATCH 04/14] fix: Nx.Random.shuffle repeating a single value in certain cases on GPU (#1552) Co-authored-by: Jonatan Klosko --- exla/lib/exla/defn.ex | 38 ++++++++++++++++++++++++------------- exla/lib/exla/mlir/value.ex | 31 +++++++++++++++++++----------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index a37492e32bb..2f85143a4d8 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1560,30 +1560,42 @@ defmodule EXLA.Defn do ## Computation helpers - defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do + defp sort_computation(operator, type, arg_typespecs, %{ + builder: %EXLA.MLIR.Function{} = function + }) do {region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs) typespec = Typespec.tensor({:pred, 8}, {}) - op = - cond do - Nx.Type.integer?(type) -> - apply(Value, op, [lhs, rhs, typespec]) - - op == :less -> - is_nan = Value.is_nan(rhs, typespec) - Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec) - - op == :greater -> - is_nan = Value.is_nan(lhs, typespec) - Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec) + {lhs, rhs} = + if Nx.Type.integer?(type) do + {lhs, rhs} + else + {sort_computation_canonicalize_float(lhs), sort_computation_canonicalize_float(rhs)} end + op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]]) + Value.return(function, [op]) Function.pop_region(function) region end + defp sort_computation_canonicalize_float(%Value{function: func} = op) do + # Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0). 
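+    # Under :totalorder comparison, -0.0 sorts before +0.0 and NaNs are ordered by their
+    # sign and payload bits, so the select calls below first map every zero to +0.0 and
+    # every NaN to a single canonical NaN, keeping sort results independent of bit patterns.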
+ # See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253 + + op_typespec = Value.get_typespec(op) + + zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {})) + zeros = Value.constant(func, [0], op_typespec) + nans = Value.constant(func, [:nan], op_typespec) + + pred_typespec = Typespec.tensor({:pred, 8}, {}) + op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec) + Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec) + end + defp op_computation( op, arg_typespecs, diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index e38d09fc0ba..2b25c6f8f66 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -54,31 +54,40 @@ defmodule EXLA.MLIR.Value do } for {op, direction} <- @bin_comparison_ops do - def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do - compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction)) + def unquote(op)( + %Value{function: func} = lhs, + %Value{function: func} = rhs, + typespec, + opts \\ [] + ) do + compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order]) end end - defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do + defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do %{type: lhs_type} = get_typespec(lhs) %{type: rhs_type} = get_typespec(rhs) comparison_type = cond do Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> - attr_comparison_type(:float) + [compare_type: attr_comparison_type(:float)] Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> - attr_comparison_type(:float) + attr = + if total_order? do + attr_comparison_type(:totalorder) + else + attr_comparison_type(:float) + end + + [compare_type: attr] true -> - attr_comparison_type(:notype) + [] end - attributes = [ - comparison_direction: attr_comparison_direction(direction), - compare_type: comparison_type - ] + attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) @@ -1072,7 +1081,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest], From f9b401fb28595dface6c922fa29b94c4284483f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 30 Oct 2024 09:01:36 +0100 Subject: [PATCH 05/14] Fix race condition in serving tests (#1554) --- nx/test/nx/serving_test.exs | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs index 2b9a6375d8c..f6aa54368d5 100644 --- a/nx/test/nx/serving_test.exs +++ b/nx/test/nx/serving_test.exs @@ -607,9 +607,10 @@ defmodule Nx.ServingTest do # One task should succeed and the other terminate assert_receive {:DOWN, ref, _, _, {{%RuntimeError{}, _}, {Nx.Serving, :local_batched_run, [_, _]}}} - when ref in [ref1, ref2] - assert_receive {:DOWN, ref, _, _, :normal} when ref in [ref1, ref2] + assert [other_ref] = [ref1, ref2] -- [ref] + + assert_receive {:DOWN, 
^other_ref, _, _, :normal} refute_received {:execute, _partition, _executor} end @@ -631,14 +632,14 @@ defmodule Nx.ServingTest do assert_receive {:execute, 0, executor} send(serving_pid, {:system, {self(), make_ref()}, {:terminate, :shutdown}}) - send(executor, :continue) - - # One task should succeed and the other terminate - assert_receive {:DOWN, ref, _, _, :normal} - when ref in [ref1, ref2] + # The queued caller should be terminated with :noproc right away assert_receive {:DOWN, ref, _, _, {:noproc, {Nx.Serving, :local_batched_run, [_, _]}}} - when ref in [ref1, ref2] + assert [other_ref] = [ref1, ref2] -- [ref] + + # The executing caller should be able to finish + send(executor, :continue) + assert_receive {:DOWN, ^other_ref, _, _, :normal} refute_received {:execute, _partition, _executor} end @@ -661,14 +662,14 @@ defmodule Nx.ServingTest do assert_receive {:execute, 0, executor} send(serving_pid, {:system, {self(), make_ref()}, {:terminate, :shutdown}}) - send(executor, :continue) - - # One task should succeed and the other terminate - assert_receive {:DOWN, ref, _, _, :normal} - when ref in [ref1, ref2] + # The stacked caller should be terminated with :noproc right away assert_receive {:DOWN, ref, _, _, {:noproc, {Nx.Serving, :local_batched_run, [_, _]}}} - when ref in [ref1, ref2] + assert [other_ref] = [ref1, ref2] -- [ref] + + # The executing caller should be able to finish + send(executor, :continue) + assert_receive {:DOWN, ^other_ref, _, _, :normal} refute_received {:execute, _partition, _executor} end From d4501d217fece095bd93b9b8af264d1449ea3c6c Mon Sep 17 00:00:00 2001 From: Ryan Moore Date: Wed, 30 Oct 2024 10:52:58 -0400 Subject: [PATCH 06/14] Fix documentation formatting (#1555) --- nx/lib/nx.ex | 98 ++++++++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index beaa293d770..ce5cb99cebb 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -8682,7 +8682,7 @@ defmodule Nx do You may set the absolute tolerance, `:atol` and relative tolerance `:rtol`. Given tolerances, this functions returns 1 if - absolute(a - b) <= (atol + rtol * absolute(b)) + absolute(a - b) <= (atol + rtol * absolute(b)) is true for all elements of a and b. 
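  As a rough, illustrative sketch (not part of the documented API), the same criterion can be
  spelled out with plain `Nx` calls; `atol` and `rtol` below stand for whatever tolerances are
  in effect:

      # Element-wise form of the formula above (illustrative only, ignores NaN handling)
      close? = fn a, b, atol, rtol ->
        Nx.abs(Nx.subtract(a, b))
        |> Nx.less_equal(Nx.add(atol, Nx.multiply(rtol, Nx.abs(b))))
        |> Nx.all()
      end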
@@ -8695,67 +8695,67 @@ defmodule Nx do ## Examples - iex> Nx.all_close(Nx.tensor([1.0e10, 1.0e-7]), Nx.tensor([1.00001e10, 1.0e-8])) - #Nx.Tensor< - u8 - 0 - > + iex> Nx.all_close(Nx.tensor([1.0e10, 1.0e-7]), Nx.tensor([1.00001e10, 1.0e-8])) + #Nx.Tensor< + u8 + 0 + > - iex> Nx.all_close(Nx.tensor([1.0e-8, 1.0e-8]), Nx.tensor([1.0e-8, 1.0e-9])) - #Nx.Tensor< - u8 - 1 - > + iex> Nx.all_close(Nx.tensor([1.0e-8, 1.0e-8]), Nx.tensor([1.0e-8, 1.0e-9])) + #Nx.Tensor< + u8 + 1 + > Although `NaN` by definition isn't equal to itself, so this implementation also considers all `NaN`s different from each other by default: - iex> Nx.all_close(Nx.tensor(:nan), Nx.tensor(:nan)) - #Nx.Tensor< - u8 - 0 - > + iex> Nx.all_close(Nx.tensor(:nan), Nx.tensor(:nan)) + #Nx.Tensor< + u8 + 0 + > - iex> Nx.all_close(Nx.tensor(:nan), Nx.tensor(0)) - #Nx.Tensor< - u8 - 0 - > + iex> Nx.all_close(Nx.tensor(:nan), Nx.tensor(0)) + #Nx.Tensor< + u8 + 0 + > We can change this behavior with the `:equal_nan` option: - iex> t = Nx.tensor([:nan, 1]) - iex> Nx.all_close(t, t, equal_nan: true) # nan == nan -> true - #Nx.Tensor< - u8 - 1 - > - iex> Nx.all_close(t, t, equal_nan: false) # nan == nan -> false, default behavior - #Nx.Tensor< - u8 - 0 - > + iex> t = Nx.tensor([:nan, 1]) + iex> Nx.all_close(t, t, equal_nan: true) # nan == nan -> true + #Nx.Tensor< + u8 + 1 + > + iex> Nx.all_close(t, t, equal_nan: false) # nan == nan -> false, default behavior + #Nx.Tensor< + u8 + 0 + > Infinities behave as expected, being "close" to themselves but not to other numbers: - iex> Nx.all_close(Nx.tensor(:infinity), Nx.tensor(:infinity)) - #Nx.Tensor< - u8 - 1 - > - - iex> Nx.all_close(Nx.tensor(:infinity), Nx.tensor(:neg_infinity)) - #Nx.Tensor< - u8 - 0 - > - - iex> Nx.all_close(Nx.tensor(1.0e30), Nx.tensor(:infinity)) - #Nx.Tensor< - u8 - 0 - > + iex> Nx.all_close(Nx.tensor(:infinity), Nx.tensor(:infinity)) + #Nx.Tensor< + u8 + 1 + > + + iex> Nx.all_close(Nx.tensor(:infinity), Nx.tensor(:neg_infinity)) + #Nx.Tensor< + u8 + 0 + > + + iex> Nx.all_close(Nx.tensor(1.0e30), Nx.tensor(:infinity)) + #Nx.Tensor< + u8 + 0 + > ## Vectorized tensors From b63d02e5a68be6966724cb66dd64616eeef7b992 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:53:18 -0300 Subject: [PATCH 07/14] fix: exla rpath (#1553) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/Makefile | 2 +- exla/mix.exs | 28 ++++++++++++---------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 695c7f9409d..ada1bb7001a 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -8,7 +8,7 @@ XLA_EXTENSION_LIB = $(XLA_EXTENSION_DIR)/lib XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include # Cache configuration -EXLA_CACHE_SO = cache/$(EXLA_VERSION)/libexla.so +EXLA_CACHE_SO = cache/libexla.so EXLA_CACHE_OBJ_DIR = cache/$(EXLA_VERSION)/objs # Private configuration diff --git a/exla/mix.exs b/exla/mix.exs index 0a09e463de6..d0c95c3ed50 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -135,26 +135,22 @@ defmodule EXLA.MixProject do end defp cached_make(args) do - force_rebuild_mode = - case System.get_env("EXLA_FORCE_REBUILD", "") do - "" -> - :none + force_rebuild_env_var = System.get_env("EXLA_FORCE_REBUILD", "") - "0" -> + force_rebuild_mode = + cond do + force_rebuild_env_var in ["", "false", "0"] -> :none - "partial" -> + force_rebuild_env_var == "partial" -> :partial - "true" -> - :full - - "1" -> + 
force_rebuild_env_var in ["true", "1"] -> :full - value -> + true -> Mix.raise( - "invalid value for EXLA_FORCE_REBUILD: '#{value}'. Expected one of: partial, true" + "invalid value for EXLA_FORCE_REBUILD: '#{force_rebuild_env_var}'. Expected one of: partial, true, false" ) end @@ -183,8 +179,8 @@ defmodule EXLA.MixProject do cached? = File.exists?(cached_so) and force_rebuild_mode == :none if force_rebuild_mode in [:partial, :full] do - Mix.shell().info("Removing cached libexla.so file in cache/#{@version}/libexla.so") - File.rm_rf!("cache/#{@version}/libexla.so") + Mix.shell().info("Removing cached libexla.so file in cache/libexla.so") + File.rm_rf!("cache/libexla.so") Mix.shell().info("Removing libexla.so cache at #{cached_so}") File.rm_rf!(cached_so) @@ -192,7 +188,7 @@ defmodule EXLA.MixProject do if cached? do Mix.shell().info("Using libexla.so from #{cached_so}") - File.cp!(cached_so, "cache/#{@version}/libexla.so") + File.cp!(cached_so, "cache/libexla.so") end result = Mix.Tasks.Compile.ElixirMake.run(args) @@ -200,7 +196,7 @@ defmodule EXLA.MixProject do if not cached? and match?({:ok, _}, result) do Mix.shell().info("Caching libexla.so at #{cached_so}") File.mkdir_p!(Path.dirname(cached_so)) - File.cp!("cache/#{@version}/libexla.so", cached_so) + File.cp!("cache/libexla.so", cached_so) end result From 3a485f135569b68231d4b613351a409afc1ab1fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 12 Nov 2024 11:02:54 +0100 Subject: [PATCH 08/14] Link to versioned XLA docs (#1558) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/lib/exla.ex | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index f4714c1f8d2..1e2a55588c8 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -6,11 +6,15 @@ defmodule EXLA do ## XLA binaries - EXLA relies on the [XLA](https://github.com/elixir-nx/xla) package to - provide the necessary XLA binaries. Whenever possible it tries to download - precompiled builds, but you may need to build from source if there is no - version matching your target environment. For more details, including - GPU/TPU support see [the usage section](https://github.com/elixir-nx/xla#usage). + EXLA relies on the `XLA` package to provide the necessary XLA binaries. + Whenever possible it tries to download precompiled builds, but you may + need to build from source if there is no version matching your target + environment. For more details, including GPU/TPU support and requirements + see the `XLA` docs. + + > #### Version requirements {: .info} + > + > For precise requirements, such as CUDA and cuDNN versions, see `XLA` docs. ## Configuration From c40e25db935c59f31cc7c4b875c5e1d677e249b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 11:51:41 +0100 Subject: [PATCH 09/14] Mark code as generated and remove dead code --- nx/lib/nx/binary_backend.ex | 29 +++++++---------------------- nx/lib/nx/shape.ex | 2 -- nx/lib/nx/shared.ex | 7 +++++++ 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 583b964bb69..5b44ee4936f 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2077,7 +2077,7 @@ defmodule Nx.BinaryBackend do for <>, into: <<>> do x = read!(x, 0) - case x do + generated_case x do %Complex{re: re} when float_output? and real_output? 
-> number_to_binary(re, output_type) @@ -2253,14 +2253,13 @@ defmodule Nx.BinaryBackend do end end - output_data = - match_types [out.type] do - for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do - re = if abs(re) <= eps, do: 0, else: re - im = if abs(im) <= eps, do: 0, else: im + %{type: {_, output_size}} = out - <> - end + output_data = + for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do + re = if abs(re) <= eps, do: 0, else: re + im = if abs(im) <= eps, do: 0, else: im + <> end intermediate_shape = out.shape |> Tuple.delete_at(axis) |> Tuple.append(n) @@ -2391,20 +2390,6 @@ defmodule Nx.BinaryBackend do end end - defp bin_zip_reduce(t1, [], t2, [], type, acc, fun) do - %{type: {_, s1}} = t1 - %{type: {_, s2}} = t2 - b1 = to_binary(t1) - b2 = to_binary(t2) - - match_types [t1.type, t2.type] do - for <>, <>, into: <<>> do - {result, _} = fun.(d1, d2, acc) - scalar_to_binary!(result, type) - end - end - end - defp bin_zip_reduce(t1, [_ | _] = axes1, t2, [_ | _] = axes2, type, acc, fun) do {_, s1} = t1.type {_, s2} = t2.type diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex index 61d7eeb9415..3992822f6da 100644 --- a/nx/lib/nx/shape.ex +++ b/nx/lib/nx/shape.ex @@ -1739,8 +1739,6 @@ defmodule Nx.Shape do end) end - defp assert_non_concat_dims_equal([], _axis), do: :ok - defp assert_non_concat_dims_equal([s1 | shapes], axis) do s1_size = tuple_size(s1) template = Tuple.delete_at(s1, axis) diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 392e7416d36..48b814727c8 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -6,6 +6,13 @@ defmodule Nx.Shared do ## Type macros + defmacro generated_case(expr, do: clauses) do + clauses = + Enum.map(clauses, fn {:->, meta, args} -> {:->, [generated: true] ++ meta, args} end) + + {:case, [generated: true], [expr, [do: clauses]]} + end + @doc """ Match the cartesian product of all given types. 
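A short aside on the `generated_case` macro added in the patch above: tagging the rewritten
`case` and each of its clauses with `generated: true` metadata tells the Elixir compiler to
treat them as machine-generated, so it skips warnings (such as clauses that can never match)
in code produced by `match_types`-style expansions, which appears to be why the binary
backend's plain `case x do` was switched to `generated_case x do`. A minimal, self-contained
sketch of the same idea follows; the module and function names are illustrative, not part of Nx:

    defmodule GeneratedCaseSketch do
      # Same shape as the Nx.Shared.generated_case/2 macro above: mark the case and each
      # clause as generated so the compiler skips diagnostics for them.
      defmacro generated_case(expr, do: clauses) do
        clauses =
          Enum.map(clauses, fn {:->, meta, args} -> {:->, [generated: true] ++ meta, args} end)

        {:case, [generated: true], [expr, [do: clauses]]}
      end
    end

    defmodule UsesSketch do
      require GeneratedCaseSketch

      def to_float(x) do
        # Behaves exactly like case/2, minus compile-time warnings about unreachable clauses.
        GeneratedCaseSketch.generated_case x do
          f when is_float(f) -> f
          i when is_integer(i) -> i * 1.0
        end
      end
    end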
From ca5eec7acfec1cd7e67a68f8f9bbaaf38c0aa1c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 12:04:29 +0100 Subject: [PATCH 10/14] Fix deprecation warnings --- nx/lib/nx.ex | 14 +++++++------- nx/lib/nx/binary_backend.ex | 6 +++--- nx/lib/nx/defn/grad.ex | 4 ++-- nx/lib/nx/lin_alg/svd.ex | 6 +++--- nx/lib/nx/random.ex | 4 ++-- nx/lib/nx/shared.ex | 7 +++++++ 6 files changed, 24 insertions(+), 17 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index ce5cb99cebb..99209dc3add 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -1305,7 +1305,7 @@ defmodule Nx do out = case shape do {n} -> - intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> Tuple.append(n) + intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> tuple_append(n) backend.eye( %T{type: type, shape: intermediate_shape, names: names}, @@ -1609,7 +1609,7 @@ defmodule Nx do t else diag_length = div(Nx.size(t), Tuple.product(batch_shape)) - Nx.reshape(t, Tuple.append(batch_shape, diag_length)) + Nx.reshape(t, tuple_append(batch_shape, diag_length)) end end @@ -10365,9 +10365,9 @@ defmodule Nx do if opts[:keep_axis] do new_shape |> Tuple.delete_at(tuple_size(new_shape) - 1) - |> Tuple.append(:auto) + |> tuple_append(:auto) else - Tuple.append(new_shape, :auto) + tuple_append(new_shape, :auto) end reshaped_tensor = reshape(tensor, flattened_shape) @@ -13554,7 +13554,7 @@ defmodule Nx do end) |> Nx.stack() |> Nx.revectorize(vectorized_axes, - target_shape: Tuple.append(List.to_tuple(lengths), :auto) + target_shape: tuple_append(List.to_tuple(lengths), :auto) ) Nx.gather(tensor, idx) @@ -14288,7 +14288,7 @@ defmodule Nx do Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn tensor, indices, _opts -> axes_range = axes(indices) - new_axis_shape = Tuple.append(shape(indices), 1) + new_axis_shape = tuple_append(shape(indices), 1) full_indices = axes_range @@ -14471,7 +14471,7 @@ defmodule Nx do indices = devectorize(indices, keep_names: false) iota_shape = - indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> Tuple.append(1) + indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> tuple_append(1) offset_axes = (offset - 1)..0//-1 diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 5b44ee4936f..e3ae4121d6f 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1492,7 +1492,7 @@ defmodule Nx.BinaryBackend do dilations = opts[:window_dilations] %T{shape: padded_shape, type: {_, size} = type} = - tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &Tuple.append(&1, 0))) + tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &tuple_append(&1, 0))) acc = scalar_to_number(acc) @@ -1608,7 +1608,7 @@ defmodule Nx.BinaryBackend do init_value = scalar_to_number(init_value) %T{shape: padded_shape, type: {_, size} = type} = - tensor = Nx.pad(t, init_value, Enum.map(padding, &Tuple.append(&1, 0))) + tensor = Nx.pad(t, init_value, Enum.map(padding, &tuple_append(&1, 0))) input_data = to_binary(tensor) input_weighted_shape = weighted_shape(padded_shape, size, window_dimensions) @@ -2262,7 +2262,7 @@ defmodule Nx.BinaryBackend do <> end - intermediate_shape = out.shape |> Tuple.delete_at(axis) |> Tuple.append(n) + intermediate_shape = out.shape |> Tuple.delete_at(axis) |> tuple_append(n) permuted_output = from_binary(%{out | shape: intermediate_shape}, output_data) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 33a7a0deed1..25b4a1178bf 100644 --- 
a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -1157,7 +1157,7 @@ defmodule Nx.Defn.Grad do num_axes = tuple_size(window_dimensions) - indices = Nx.reshape(indices_to_flatten, Tuple.append(source.shape, num_axes)) + indices = Nx.reshape(indices_to_flatten, Nx.Shared.tuple_append(source.shape, num_axes)) dsource = Nx.gather(g, indices) dtensor = Nx.broadcast(0, tensor) @@ -1490,7 +1490,7 @@ defmodule Nx.Defn.Grad do end defp grad_scatter_window__gather_windows(tensor, window_dimensions, strides, padding) do - tensor = Nx.pad(tensor, 0, Enum.map(padding, &Tuple.append(&1, 0))) + tensor = Nx.pad(tensor, 0, Enum.map(padding, &Nx.Shared.tuple_append(&1, 0))) shape_l = Tuple.to_list(tensor.shape) window_dims_l = Tuple.to_list(window_dimensions) diff --git a/nx/lib/nx/lin_alg/svd.ex b/nx/lib/nx/lin_alg/svd.ex index 338720d0db4..7e1cad84f0d 100644 --- a/nx/lib/nx/lin_alg/svd.ex +++ b/nx/lib/nx/lin_alg/svd.ex @@ -59,9 +59,9 @@ defmodule Nx.LinAlg.SVD do collapsed_axes = shape |> Tuple.delete_at(rank - 2) |> Tuple.delete_at(rank - 2) - u_shape = collapsed_axes |> Tuple.append(m) |> Tuple.append(:auto) - s_shape = Tuple.append(collapsed_axes, :auto) - vt_shape = Tuple.append(s_shape, n) + u_shape = collapsed_axes |> Nx.Shared.tuple_append(m) |> Nx.Shared.tuple_append(:auto) + s_shape = Nx.Shared.tuple_append(collapsed_axes, :auto) + vt_shape = Nx.Shared.tuple_append(s_shape, n) {{m, n}, u_shape, s_shape, vt_shape} end diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index 5b429517cb2..2523c5114dd 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -746,7 +746,7 @@ defmodule Nx.Random do {dim} dims when is_tuple(dims) -> - Tuple.append(dims, dim) + Nx.Shared.tuple_append(dims, dim) _ -> raise ArgumentError, @@ -1126,7 +1126,7 @@ defmodule Nx.Random do case type do {:c, _} -> type = Nx.Type.to_real(type) - data = fun.(key, type, Tuple.append(shape, 2)) + data = fun.(key, type, Nx.Shared.tuple_append(shape, 2)) to_complex = Nx.stack([1, Nx.Constants.i()]) Nx.dot(data, to_complex) diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 48b814727c8..072f7db792a 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -445,6 +445,13 @@ defmodule Nx.Shared do ## Helpers + @doc """ + Appends an element to a tuple. + """ + def tuple_append(tuple, elem) do + Tuple.insert_at(tuple, tuple_size(tuple), elem) + end + @doc """ Extracts the backend from the given options. 
""" From f430b5b8d5ce859bbd692ca00efc6f8a236e47eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 12:05:58 +0100 Subject: [PATCH 11/14] Fix test relying on map ordering --- nx/test/nx/defn_test.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index dbb560839b9..d532ea40434 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -1995,7 +1995,7 @@ defmodule Nx.DefnTest do defn while_mixed_return(a, b) do while {a, b}, Nx.less(a, 10) do - %{a: a, b: b} + %{"a" => a, "b" => b} end end @@ -2003,7 +2003,7 @@ defmodule Nx.DefnTest do expected_error = [ "the do-block in while must return tensors with the same shape, type, and names as the initial arguments.", - "\n\n\e\\[32m\n<<<<< Body \\(do-block\\) <<<<<\n%\\{a: #Nx.Tensor<\n s32\n >, b: #Nx.Tensor<\n s32\n >\\}", + "\n\n\e\\[32m\n<<<<< Body \\(do-block\\) <<<<<\n%\\{\"a\" => #Nx.Tensor<\n s32\n >, \"b\" => #Nx.Tensor<\n s32\n >\\}", "\n==========\n\e\\[31m\\{#Nx.Tensor<\n s32\n >, #Nx.Tensor<\n s32\n >\\}\n>>>>> Initial >>>>>\n\e\\[0m\n$" ] |> IO.iodata_to_binary() From 2bdcfe811a450d69f2b6df7b5923b1bdf08e9944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 12:11:37 +0100 Subject: [PATCH 12/14] Update deps --- nx/mix.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nx/mix.lock b/nx/mix.lock index fab455ebe4d..42e4470cd00 100644 --- a/nx/mix.lock +++ b/nx/mix.lock @@ -2,8 +2,8 @@ "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, + "makeup_elixir": {:hex, :makeup_elixir, "1.0.0", 
"74bb8348c9b3a51d5c589bf5aebb0466a84b33274150e3b6ece1da45584afc82", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "49159b7d7d999e836bedaf09dcf35ca18b312230cf901b725a64f3f42e407983"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, From 762c3c07022cd2391651af14713efaf3782352de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 12:22:01 +0100 Subject: [PATCH 13/14] Fix warnings in Torchx and EXLA --- exla/lib/exla/defn.ex | 4 ++-- torchx/lib/torchx/backend.ex | 22 +++++++--------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 2f85143a4d8..d7ce33c0ab6 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -980,8 +980,8 @@ defmodule EXLA.Defn do transform = Keyword.fetch!(opts, :transform_a) case Value.get_typespec(b).shape do - {_} = b_shape -> - b_shape = Tuple.append(b_shape, 1) + {dim} -> + b_shape = {dim, 1} b = b diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 18f453dc6d7..77308ce94de 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -506,7 +506,7 @@ defmodule Torchx.Backend do result = if axes == [] do - aggregate_whole_tensor(t, keep_axes, &Torchx.product/1) + aggregate_whole_tensor(t, &Torchx.product/1) else aggregate_over_axes(t, axes, keep_axes, &Torchx.product/3) end @@ -523,7 +523,7 @@ defmodule Torchx.Backend do result = if axes == [] do - aggregate_whole_tensor(t, keep_axes, &Torchx.any/1) + aggregate_whole_tensor(t, &Torchx.any/1) else aggregate_over_axes(t, axes, keep_axes, &Torchx.any/3) end @@ -538,7 +538,7 @@ defmodule Torchx.Backend do result = if axes == [] do - aggregate_whole_tensor(t, keep_axes, &Torchx.all/1) + aggregate_whole_tensor(t, &Torchx.all/1) else aggregate_over_axes(t, axes, keep_axes, &Torchx.all/3) end @@ -563,18 +563,10 @@ defmodule Torchx.Backend do |> to_nx(out) end - defp aggregate_whole_tensor(t, keep_axes, fun) when is_function(fun, 1) do - result = - t - |> from_nx() - |> then(fun) - - if keep_axes do - shape = t.shape |> Tuple.delete_at(-1) |> Tuple.append(1) - Torchx.reshape(result, shape) - else - result - end + defp aggregate_whole_tensor(t, fun) when is_function(fun, 1) do + t + |> from_nx() + |> then(fun) end defp aggregate_over_axes(t, axes, keep_axes, fun) when is_function(fun, 3) do From be4ad6688bba8d35632be3dd04aef01c7f4c9b2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 16 Nov 2024 12:23:03 +0100 Subject: [PATCH 14/14] Release v0.9.2 --- exla/CHANGELOG.md | 7 +++++++ exla/mix.exs | 2 +- exla/mix.lock | 4 ++-- nx/CHANGELOG.md | 8 ++++++++ nx/mix.exs | 2 +- torchx/CHANGELOG.md | 4 ++++ torchx/mix.exs | 2 +- torchx/mix.lock | 4 ++-- 8 files changed, 26 insertions(+), 7 deletions(-) 
diff --git a/exla/CHANGELOG.md b/exla/CHANGELOG.md index 466505a84ec..a06ca4b41ea 100644 --- a/exla/CHANGELOG.md +++ b/exla/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v0.9.2 (2024-11-16) + +### Enhancements + + * Support cross-compilation for use with Nerves + * Optimize LU with a custom call + ## v0.9.1 (2024-10-08) ### Enhancements diff --git a/exla/mix.exs b/exla/mix.exs index d0c95c3ed50..a8244e332db 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -2,7 +2,7 @@ defmodule EXLA.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/nx" - @version "0.9.1" + @version "0.9.2" def project do make_args = diff --git a/exla/mix.lock b/exla/mix.lock index cb55ce51fc8..2ae6c6f7612 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -5,8 +5,8 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, + "makeup_elixir": {:hex, :makeup_elixir, "1.0.0", "74bb8348c9b3a51d5c589bf5aebb0466a84b33274150e3b6ece1da45584afc82", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "49159b7d7d999e836bedaf09dcf35ca18b312230cf901b725a64f3f42e407983"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, "nimble_parsec": {:hex, 
:nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, diff --git a/nx/CHANGELOG.md b/nx/CHANGELOG.md index 69f8c81c16c..23aaa252687 100644 --- a/nx/CHANGELOG.md +++ b/nx/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v0.9.2 (2024-11-16) + +### Bug fixes + + * [Nx] Fix deprecation warnings on latest Elixir + * [Nx.LinAlg] Fix `least_squares` implementation + * [Nx.Random] Fix `Nx.Random.shuffle` repeating a single value in certain cases on GPU + ## v0.9.1 (2024-10-08) ### Deprecations diff --git a/nx/mix.exs b/nx/mix.exs index effc66eb4fd..d232c9c67a3 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -2,7 +2,7 @@ defmodule Nx.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/nx" - @version "0.9.1" + @version "0.9.2" def project do [ diff --git a/torchx/CHANGELOG.md b/torchx/CHANGELOG.md index db3c8f80782..f25ae812bf2 100644 --- a/torchx/CHANGELOG.md +++ b/torchx/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## v0.9.2 (2024-11-16) + + * Update to latest Nx + ## v0.9.1 (2024-10-08) * Update to latest Nx diff --git a/torchx/mix.exs b/torchx/mix.exs index 070718c3b03..79b0b5f7576 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -2,7 +2,7 @@ defmodule Torchx.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/nx" - @version "0.9.1" + @version "0.9.2" @libtorch_compilers [:torchx, :cmake] diff --git a/torchx/mix.lock b/torchx/mix.lock index b3a93517e0e..75bfa5fe46a 100644 --- a/torchx/mix.lock +++ b/torchx/mix.lock @@ -2,8 +2,8 @@ "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup": {:hex, :makeup, "1.2.1", 
"e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, + "makeup_elixir": {:hex, :makeup_elixir, "1.0.0", "74bb8348c9b3a51d5c589bf5aebb0466a84b33274150e3b6ece1da45584afc82", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "49159b7d7d999e836bedaf09dcf35ca18b312230cf901b725a64f3f42e407983"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"},