From 219c23ecd75f93a7ad714e9b7b1e01d4a0bdf4bd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 6 Jun 2023 04:33:44 -0300 Subject: [PATCH] feat: improve container diffing (#1242) --- nx/lib/nx/defn.ex | 16 +-- nx/lib/nx/defn/expr.ex | 34 ++---- nx/lib/nx/defn/template_diff.ex | 160 +++++++++++++++++++++++++++++ nx/test/nx/container_test.exs | 16 ++- nx/test/nx/defn/evaluator_test.exs | 23 ++--- nx/test/nx/defn_test.exs | 134 ++++++++++++------------ 6 files changed, 254 insertions(+), 129 deletions(-) create mode 100644 nx/lib/nx/defn/template_diff.ex diff --git a/nx/lib/nx/defn.ex b/nx/lib/nx/defn.ex index f451a71ccc..aa121e487f 100644 --- a/nx/lib/nx/defn.ex +++ b/nx/lib/nx/defn.ex @@ -323,21 +323,7 @@ defmodule Nx.Defn do raise ArgumentError, """ argument at position #{pos} is not compatible with compiled function template. - Expected template: - - #{inspect(template)} - - Argument template: - - #{inspect(arg_template)} - - Expected argument: - - #{inspect(Enum.fetch!(template_args, pos - 1))} - - Actual argument: - - #{inspect(arg)} + #{Nx.Defn.TemplateDiff.build_and_inspect(Enum.fetch!(template_args, pos - 1), arg, "Expected", "Argument")} """ end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index ed627a108c..5bb68493c3 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -429,8 +429,9 @@ defmodule Nx.Defn.Expr do [{_, first_pred, first} | rest] -> first = first.() - [{last_pred, last} | reverse] = - Enum.reduce(rest, [{first_pred, first}], fn {meta, pred, expr}, acc -> + {[{last_pred, last} | reverse], _} = + Enum.reduce(rest, {[{first_pred, first}], 1}, fn {meta, pred, expr}, + {acc, branch_idx} -> expr = expr.() if not Nx.Defn.Composite.compatible?(first, expr, fn _, _ -> true end) do @@ -440,17 +441,11 @@ defmodule Nx.Defn.Expr do description: """ cond/if expects all branches to return compatible tensor types. - Got mismatching templates: - - #{inspect_as_template(first)} - - and - - #{inspect_as_template(expr)} + #{Nx.Defn.TemplateDiff.build_and_inspect(first, expr, "First Branch (expected)", "Branch #{branch_idx}", fn _, _ -> true end)} """ end - [{pred, expr} | acc] + {[{pred, expr} | acc], branch_idx + 1} end) case last_pred do @@ -765,13 +760,7 @@ defmodule Nx.Defn.Expr do description: """ the do-block in while must return tensors with the same shape, type, and names as the initial arguments. - Body matches template: - - #{inspect_as_template(body)} - - and initial argument has template: - - #{inspect_as_template(initial)} + #{Nx.Defn.TemplateDiff.build_and_inspect(body, initial, "Body (do-block)", "Initial")} """ end end @@ -1470,17 +1459,6 @@ defmodule Nx.Defn.Expr do context || acc end - defp inspect_as_template(data) do - if is_number(data) or is_tuple(data) or - (is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do - data - |> Nx.to_template() - |> Kernel.inspect(custom_options: [skip_template_backend_header: true]) - else - inspect(data) - end - end - ## Constant helpers and related optimizations defp constant(%{shape: shape, type: type} = out, number) do diff --git a/nx/lib/nx/defn/template_diff.ex b/nx/lib/nx/defn/template_diff.ex new file mode 100644 index 0000000000..71cc882d7f --- /dev/null +++ b/nx/lib/nx/defn/template_diff.ex @@ -0,0 +1,160 @@ +defmodule Nx.Defn.TemplateDiff do + @moduledoc false + defstruct [:left, :right, :left_title, :right_title, :compatible] + + defp is_valid_container?(impl) do + not is_nil(impl) and impl != Nx.Container.Any + end + + def build(left, right, left_title, right_title, compatibility_fn \\ &Nx.compatible?/2) do + left_impl = Nx.Container.impl_for(left) + right_impl = Nx.Container.impl_for(right) + + l = is_valid_container?(left_impl) + r = is_valid_container?(right_impl) + + cond do + not l and not r -> + %__MODULE__{ + left: left, + left_title: left_title, + right: right, + right_title: right_title, + compatible: left == right + } + + not l or not r -> + %__MODULE__{ + left: left, + left_title: left_title, + right: right, + right_title: right_title, + compatible: false + } + + left_impl != right_impl -> + %__MODULE__{ + left: left, + left_title: left_title, + right: right, + right_title: right_title, + compatible: false + } + + l and r -> + {diff, acc} = + Nx.Defn.Composite.traverse(left, Nx.Defn.Composite.flatten_list([right]), fn + left, [] -> + {%__MODULE__{left: left}, :incompatible_sizes} + + left, [right | acc] -> + { + %__MODULE__{ + left: left, + right: right, + left_title: left_title, + right_title: right_title, + compatible: compatibility_fn.(left, right) + }, + acc + } + end) + + if acc == :incompatible_sizes do + %__MODULE__{ + left: left, + left_title: left_title, + right: right, + right_title: right_title, + compatible: false + } + else + diff + end + end + end + + def build_and_inspect( + left, + right, + left_title, + right_title, + compatibility_fn \\ &Nx.compatible?/2 + ) do + left + |> build(right, left_title, right_title, compatibility_fn) + |> inspect() + end + + defimpl Inspect do + import Inspect.Algebra + + def inspect(%Nx.Defn.TemplateDiff{left: left, right: nil}, opts) do + inspect_as_template(left, opts) + end + + def inspect(%Nx.Defn.TemplateDiff{left: left, compatible: true}, opts) do + inspect_as_template(left, opts) + end + + def inspect( + %Nx.Defn.TemplateDiff{ + left: left, + left_title: left_title, + right: right, + right_title: right_title + }, + opts + ) do + {left_title, right_title} = centralize_titles(left_title, right_title) + + concat([ + IO.ANSI.green(), + line(), + "<<<<< #{left_title} <<<<<", + line(), + inspect_as_template(left, opts), + line(), + "==========", + line(), + IO.ANSI.red(), + inspect_as_template(right, opts), + line(), + ">>>>> #{right_title} >>>>>", + line(), + IO.ANSI.reset() + ]) + end + + defp centralize_titles(l, r) do + l_len = String.length(l) + r_len = String.length(r) + max_len = max(l_len, r_len) + + {centralize_string(l, l_len, max_len), centralize_string(r, r_len, max_len)} + end + + defp centralize_string(s, n, n), do: s + + defp centralize_string(s, l, n) do + pad = div(n - l, 2) + + s + |> String.pad_leading(l + pad) + |> String.pad_trailing(n) + end + + defp inspect_as_template(data, opts) do + if is_number(data) or is_tuple(data) or + (is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do + data + |> Nx.to_template() + |> to_doc( + update_in(opts.custom_options, &Keyword.put(&1, :skip_template_backend_header, true)) + ) + else + to_doc(data, opts) + end + end + end +end diff --git a/nx/test/nx/container_test.exs b/nx/test/nx/container_test.exs index 49c9d0333a..dc7e75f23c 100644 --- a/nx/test/nx/container_test.exs +++ b/nx/test/nx/container_test.exs @@ -141,7 +141,7 @@ defmodule Nx.ContainerTest do var.a + var.b end - deftransformp assert_fields!(%C{c: %{}, d: :keep}), do: 1 + deftransformp(assert_fields!(%C{c: %{}, d: :keep}), do: 1) test "keeps fields" do inp = %Container{a: 1, b: 2, c: :reset, d: :keep} @@ -158,7 +158,7 @@ defmodule Nx.ContainerTest do var.a + var.b end - deftransformp dot_assert_fields_transform(%C{c: %{}, d: %{}}), do: 1 + deftransformp(dot_assert_fields_transform(%C{c: %{}, d: %{}}), do: 1) test "keeps empty maps" do inp = %Container{a: 1, b: 2, c: :reset, d: %{}} @@ -194,10 +194,8 @@ defmodule Nx.ContainerTest do expected_error = [ "the do-block in while must return tensors with the same shape, type, and names as the initial arguments.", - "\n\nBody matches template:\n\n{#Nx.Tensor<\n s64\n >, ", - "%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}, #Nx.Tensor<\n s16\n >}", - "\n\nand initial argument has template:\n\n{#Nx.Tensor<\n s64\n >, ", - "%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}, #Nx.Tensor<\n u8\n >}\n$" + "\n\n{#Nx.Tensor<\n s64\n >, %Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >,", + " c: %{}, d: %{}}, \e\\[32m\n <<<<< Body \\(do-block\\) <<<<<\n #Nx.Tensor<\n s16\n >\n ==========\n \e\\[31m#Nx.Tensor<\n u8\n >\n >>>>> Initial >>>>>\n \e\\[0m}\n$" ] |> IO.iodata_to_binary() |> Regex.compile!() @@ -211,9 +209,9 @@ defmodule Nx.ContainerTest do expected_error = [ "cond/if expects all branches to return compatible tensor types.", - "\n\nGot mismatching templates:\n\n%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}", - "\n\nand\n\n#Nx.Tensor<\n s64\n>\n", - "$" + "\n\n\e\\[32m\n<<<<< First Branch \\(expected\\) <<<<<\n%Container", + "{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}", + "\n==========\n\e\\[31m#Nx.Tensor<\n s64\n>\n>>>>> Branch 1 >>>>>\n\e\\[0m\n$" ] |> IO.iodata_to_binary() |> Regex.compile!() diff --git a/nx/test/nx/defn/evaluator_test.exs b/nx/test/nx/defn/evaluator_test.exs index 5a6a67b87c..29634f5b59 100644 --- a/nx/test/nx/defn/evaluator_test.exs +++ b/nx/test/nx/defn/evaluator_test.exs @@ -679,24 +679,19 @@ defmodule Nx.Defn.EvaluatorTest do message = """ test/nx/defn/evaluator_test.exs:646: the do-block in while must return tensors with the same shape, type, and names as the initial arguments. - Body matches template: - - {#Nx.Tensor< + {\e[32m + <<<<< Body (do-block) <<<<< + #Nx.Tensor< vectorized[a: 2] s64[2][3] - >, #Nx.Tensor< - vectorized[a: 2] - s64[2][3] - >, #Nx.Tensor< - s64 - >} - - and initial argument has template: - - {#Nx.Tensor< + > + ========== + \e[31m#Nx.Tensor< vectorized[a: 1] s64[2][3] - >, #Nx.Tensor< + > + >>>>> Initial >>>>> + \e[0m, #Nx.Tensor< vectorized[a: 2] s64[2][3] >, #Nx.Tensor< diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 0847b4aa21..4b3617bb56 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -40,7 +40,7 @@ defmodule Nx.DefnTest do describe "Nx.tensor" do test "does not warn on negative values" do defmodule NegConstant do - defn this_wont_warn, do: Nx.tensor(1) + Nx.tensor(-1) + defn(this_wont_warn, do: Nx.tensor(1) + Nx.tensor(-1)) end end @@ -247,7 +247,7 @@ defmodule Nx.DefnTest do end describe "unary ops" do - defn exp(t), do: Nx.exp(t) + defn(exp(t), do: Nx.exp(t)) defn unary_plus_minus_guards(opts \\ []) do case opts[:value] do @@ -289,9 +289,9 @@ defmodule Nx.DefnTest do end describe "binary ops" do - defn add(t1, t2), do: Nx.add(t1, t2) - defn add_two_int(t), do: Nx.add(t, 2) - defn add_two_float(t), do: Nx.add(t, 2) + defn(add(t1, t2), do: Nx.add(t1, t2)) + defn(add_two_int(t), do: Nx.add(t, 2)) + defn(add_two_float(t), do: Nx.add(t, 2)) test "to expr" do assert %T{shape: {3}, type: {:s, 64}, data: %Expr{op: :add, args: [_, _]}} = @@ -317,10 +317,10 @@ defmodule Nx.DefnTest do end describe "aggregate axes ops" do - defn sum_all(t), do: Nx.sum(t) - defn sum_pos(t), do: Nx.sum(t, axes: [0, 1]) - defn sum_neg(t), do: Nx.sum(t, axes: [-1, -2]) - defn sum_keep(t), do: Nx.sum(t, axes: [0, 1], keep_axes: true) + defn(sum_all(t), do: Nx.sum(t)) + defn(sum_pos(t), do: Nx.sum(t, axes: [0, 1])) + defn(sum_neg(t), do: Nx.sum(t, axes: [-1, -2])) + defn(sum_keep(t), do: Nx.sum(t, axes: [0, 1], keep_axes: true)) test "to expr" do assert %T{ @@ -362,8 +362,8 @@ defmodule Nx.DefnTest do end describe "creation ops" do - defn iota(t), do: Nx.iota(Nx.shape(t)) - defn eye, do: Nx.eye(2) + defn(iota(t), do: Nx.iota(Nx.shape(t))) + defn(eye, do: Nx.eye(2)) test "iota" do assert %T{shape: {3}, data: %Expr{op: :iota, args: [nil]}} = iota(Nx.tensor([1, 2, 3])) @@ -381,11 +381,11 @@ defmodule Nx.DefnTest do end describe "tensor ops" do - defn dot2(t1, t2), do: Nx.dot(t1, t2) - defn dot6(t1, t2), do: Nx.dot(t1, [-2], [], t2, [-1], []) - defn transpose_1(t), do: Nx.transpose(t) - defn transpose_2(t), do: Nx.transpose(t, axes: [-1, -2]) - defn reshape(t), do: Nx.reshape(t, {2, 3}) + defn(dot2(t1, t2), do: Nx.dot(t1, t2)) + defn(dot6(t1, t2), do: Nx.dot(t1, [-2], [], t2, [-1], [])) + defn(transpose_1(t), do: Nx.transpose(t)) + defn(transpose_2(t), do: Nx.transpose(t, axes: [-1, -2])) + defn(reshape(t), do: Nx.reshape(t, {2, 3})) test "dot product" do assert %T{data: %Expr{op: :dot, args: [_, [0], _, _, [0], _]}, shape: {2}} = @@ -416,8 +416,8 @@ defmodule Nx.DefnTest do end describe "broadcast" do - defn broadcast(t), do: Nx.broadcast(t, {3, 3, 3}) - defn broadcast_axes(t), do: Nx.broadcast(t, {3, 2}, axes: [-2]) + defn(broadcast(t), do: Nx.broadcast(t, {3, 3, 3})) + defn(broadcast_axes(t), do: Nx.broadcast(t, {3, 2}, axes: [-2])) test "with and without axes" do assert %T{data: %Expr{op: :broadcast, args: [_, _, [2]]}, shape: {3, 3, 3}} = @@ -427,26 +427,32 @@ defmodule Nx.DefnTest do broadcast_axes(Nx.tensor([1, 2, 3])) end - defn broadcast_collapse1(t), do: t |> Nx.broadcast({5, 3}) |> Nx.broadcast({7, 5, 3}) + defn(broadcast_collapse1(t), do: t |> Nx.broadcast({5, 3}) |> Nx.broadcast({7, 5, 3})) - defn broadcast_collapse2(t), + defn(broadcast_collapse2(t), do: t |> Nx.broadcast({3, 5}, axes: [0]) |> Nx.broadcast({3, 5, 7}, axes: [0, 1]) + ) - defn broadcast_collapse3(t), + defn(broadcast_collapse3(t), do: t |> Nx.broadcast({3, 5}, axes: [0]) |> Nx.broadcast({3, 7, 5}, axes: [0, 2]) + ) - defn broadcast_collapse4(t), + defn(broadcast_collapse4(t), do: t |> Nx.broadcast({3, 5}, axes: [0]) |> Nx.broadcast({7, 3, 5}, axes: [1, 2]) + ) - defn broadcast_collapse5(t), + defn(broadcast_collapse5(t), do: t |> Nx.broadcast({5, 3}) |> Nx.broadcast({7, 5, 3, 9}, axes: [1, 2]) + ) - defn broadcast_collapse6(t), + defn(broadcast_collapse6(t), do: t |> Nx.broadcast({5, 3, 7}, axes: [1]) |> Nx.broadcast({9, 5, 3, 7}, axes: [1, 2, 3]) + ) - defn broadcast_collapse7(t), + defn(broadcast_collapse7(t), do: t |> Nx.broadcast({3, 5, 7}, axes: [0, 2]) |> Nx.broadcast({3, 9, 5, 7}, axes: [0, 2, 3]) + ) test "collapses" do assert %T{data: %Expr{op: :broadcast, args: [_, {7, 5, 3}, [1]]}, shape: {7, 5, 3}} = @@ -473,16 +479,16 @@ defmodule Nx.DefnTest do end describe "squeeze" do - defn squeeze(t), do: Nx.squeeze(t) + defn(squeeze(t), do: Nx.squeeze(t)) test "sized one dimensions" do assert %T{data: %Expr{op: :squeeze, args: [_, [0, 2, 4]]}, shape: {3, 2}} = squeeze(Nx.iota({1, 3, 1, 2, 1})) end - defn squeeze_collapse1(t), do: t |> Nx.squeeze(axes: [0, 2]) |> Nx.squeeze(axes: [0, 2]) - defn squeeze_collapse2(t), do: t |> Nx.squeeze(axes: [3, 1]) |> Nx.squeeze(axes: [2]) - defn squeeze_collapse3(t), do: t |> Nx.squeeze(axes: [2]) |> Nx.squeeze(axes: [3, 1]) + defn(squeeze_collapse1(t), do: t |> Nx.squeeze(axes: [0, 2]) |> Nx.squeeze(axes: [0, 2])) + defn(squeeze_collapse2(t), do: t |> Nx.squeeze(axes: [3, 1]) |> Nx.squeeze(axes: [2])) + defn(squeeze_collapse3(t), do: t |> Nx.squeeze(axes: [2]) |> Nx.squeeze(axes: [3, 1])) test "with explicit dimensions are collapsed" do assert %T{data: %Expr{op: :squeeze, args: [_, [0, 1, 2, 4]]}, shape: {1}, names: [:d]} = @@ -497,7 +503,7 @@ defmodule Nx.DefnTest do end describe "conditional ops" do - defn select(t1, t2, t3), do: Nx.select(t1, t2, t3) + defn(select(t1, t2, t3), do: Nx.select(t1, t2, t3)) test "select with tensor predicate" do assert %{data: %Expr{op: :select, args: [_, _, _]}, shape: {2, 2}} = @@ -514,16 +520,17 @@ defmodule Nx.DefnTest do end describe "reduce ops" do - defn reduce(t1, acc), do: Nx.reduce(t1, acc, fn x, y -> x + y end) + defn(reduce(t1, acc), do: Nx.reduce(t1, acc, fn x, y -> x + y end)) - defn reduce_static(t1, acc), do: Nx.reduce(t1, acc, fn _, _ -> 0 end) + defn(reduce_static(t1, acc), do: Nx.reduce(t1, acc, fn _, _ -> 0 end)) - defn reduce_invalid(t1, amplifier), do: Nx.reduce(t1, 0, fn x, y -> x * amplifier + y end) + defn(reduce_invalid(t1, amplifier), do: Nx.reduce(t1, 0, fn x, y -> x * amplifier + y end)) - defn reduce_non_scalar(t1), do: Nx.reduce(t1, 0, fn x, y -> Nx.broadcast(x * y, {1, 1}) end) + defn(reduce_non_scalar(t1), do: Nx.reduce(t1, 0, fn x, y -> Nx.broadcast(x * y, {1, 1}) end)) - defn reduce_with_opts(t1, acc), + defn(reduce_with_opts(t1, acc), do: Nx.reduce(t1, acc, [type: {:f, 64}, axes: [-1]], fn x, y -> x + y end) + ) test "reduces with function" do assert %{ @@ -575,31 +582,31 @@ defmodule Nx.DefnTest do end describe "operators" do - defn add_two(a, b), do: a + b + defn(add_two(a, b), do: a + b) test "+" do assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two(1, 2) end - defn subtract_two(a, b), do: a - b + defn(subtract_two(a, b), do: a - b) test "-" do assert %T{data: %Expr{op: :subtract, args: [_, _]}} = subtract_two(1, 2) end - defn multiply_two(a, b), do: a * b + defn(multiply_two(a, b), do: a * b) test "*" do assert %T{data: %Expr{op: :multiply, args: [_, _]}} = multiply_two(1, 2) end - defn divide_two(a, b), do: a / b + defn(divide_two(a, b), do: a / b) test "/" do assert %T{data: %Expr{op: :divide, args: [_, _]}} = divide_two(1, 2) end - defn land_two(a, b), do: a and b + defn(land_two(a, b), do: a and b) defn land_true(a) do true and a @@ -609,7 +616,7 @@ defmodule Nx.DefnTest do assert %T{data: %Expr{op: :logical_and, args: [_, _]}} = land_two(1, 2) end - defn lor_two(a, b), do: a or b + defn(lor_two(a, b), do: a or b) defn lor_true(opts \\ []) do true or opts[:value] @@ -619,8 +626,8 @@ defmodule Nx.DefnTest do assert %T{data: %Expr{op: :logical_or, args: [_, _]}} = lor_two(1, 2) end - defn lnot(a), do: not a - defn lnot_boolean(opts \\ []), do: not constant_boolean_transform(opts) + defn(lnot(a), do: not a) + defn(lnot_boolean(opts \\ []), do: not constant_boolean_transform(opts)) deftransformp constant_boolean_transform(opts) do if opts[:value] == true do @@ -630,81 +637,81 @@ defmodule Nx.DefnTest do end end - defn band_two(a, b), do: a &&& b + defn(band_two(a, b), do: a &&& b) test "&&&" do assert %T{data: %Expr{op: :bitwise_and, args: [_, _]}} = band_two(1, 2) end - defn bor_two(a, b), do: a ||| b + defn(bor_two(a, b), do: a ||| b) test "|||" do assert %T{data: %Expr{op: :bitwise_or, args: [_, _]}} = bor_two(1, 2) end - defn bsl_two(a, b), do: a <<< b + defn(bsl_two(a, b), do: a <<< b) test "<<<" do assert %T{data: %Expr{op: :left_shift, args: [_, _]}} = bsl_two(1, 2) end - defn bsr_two(a, b), do: a >>> b + defn(bsr_two(a, b), do: a >>> b) test ">>>" do assert %T{data: %Expr{op: :right_shift, args: [_, _]}} = bsr_two(1, 2) end - defn add_two_with_pipe(a, b), do: a |> Nx.add(b) + defn(add_two_with_pipe(a, b), do: a |> Nx.add(b)) test "|>" do assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two_with_pipe(1, 2) end - defn unary_plus(a), do: +a - defn unary_minus(a), do: -a + defn(unary_plus(a), do: +a) + defn(unary_minus(a), do: -a) test "unary plus and minus" do assert %T{data: %Expr{op: :parameter, args: [_]}} = unary_plus(1) assert %T{data: %Expr{op: :negate, args: [_]}} = unary_minus(1) end - defn unary_bnot(a), do: ~~~a + defn(unary_bnot(a), do: ~~~a) test "~~~" do assert %T{data: %Expr{op: :bitwise_not, args: [_]}} = unary_bnot(1) end - defn equality(a, b), do: a == b + defn(equality(a, b), do: a == b) test "==" do assert %T{data: %Expr{op: :equal, args: [_, _]}} = equality(1, 2) end - defn inequality(a, b), do: a != b + defn(inequality(a, b), do: a != b) test "!=" do assert %T{data: %Expr{op: :not_equal, args: [_, _]}} = inequality(1, 2) end - defn less_than(a, b), do: a < b + defn(less_than(a, b), do: a < b) test "<" do assert %T{data: %Expr{op: :less, args: [_, _]}} = less_than(1, 2) end - defn greater_than(a, b), do: a > b + defn(greater_than(a, b), do: a > b) test ">" do assert %T{data: %Expr{op: :greater, args: [_, _]}} = greater_than(1, 2) end - defn less_than_or_equal(a, b), do: a <= b + defn(less_than_or_equal(a, b), do: a <= b) test "<=" do assert %T{data: %Expr{op: :less_equal, args: [_, _]}} = less_than_or_equal(1, 2) end - defn greater_than_or_equal(a, b), do: a >= b + defn(greater_than_or_equal(a, b), do: a >= b) test ">=" do assert %T{data: %Expr{op: :greater_equal, args: [_, _]}} = greater_than_or_equal(1, 2) @@ -1246,7 +1253,7 @@ defmodule Nx.DefnTest do test "raises on non-tensor return" do assert_raise CompileError, - ~r"cond/if expects all branches to return compatible tensor types.\n\nGot mismatching templates:\n\n:foo\n\nand\n\n:bar\n", + ~r"cond/if expects all branches to return compatible tensor types.\n\n\e\[32m\n<<<<< First Branch \(expected\) <<<<<\n:foo\n==========\n\e\[31m:bar\n>>>>> Branch 1 >>>>>\n\e\[0m\n", fn -> non_tensor_cond(1) end end @@ -1565,9 +1572,10 @@ defmodule Nx.DefnTest do expected_error = [ - "the do-block in while must return tensors with the same shape, type, and names as the initial arguments.", - "\n\nBody matches template:\n\n{#Nx.Tensor<\n f32\n >, #Nx.Tensor<\n f32\n >}", - "\n\nand initial argument has template:\n\n{#Nx.Tensor<\n s64\n >, #Nx.Tensor<\n f32\n >}\n" + "the do-block in while must return tensors with the same shape, type, and names ", + "as the initial arguments.\n\n\\{\e\\[32m\n <<<<< Body \\(do-block\\) <<<<<\n ", + "#Nx.Tensor<\n f32\n >\n ==========\n \e\\[31m#Nx.Tensor<\n s64\n >\n >>>>>", + " Initial >>>>>\n \e\\[0m, #Nx.Tensor<\n f32\n >\\}\n$" ] |> IO.iodata_to_binary() |> Regex.compile!() @@ -1987,8 +1995,8 @@ defmodule Nx.DefnTest do expected_error = [ "the do-block in while must return tensors with the same shape, type, and names as the initial arguments.", - "\n\nBody matches template:\n\n%{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >}", - "\n\nand initial argument has template:\n\n{#Nx.Tensor<\n s64\n >, #Nx.Tensor<\n s64\n >}\n" + "\n\n\e\\[32m\n<<<<< Body \\(do-block\\) <<<<<\n%\\{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >\\}", + "\n==========\n\e\\[31m\\{#Nx.Tensor<\n s64\n >, #Nx.Tensor<\n s64\n >\\}\n>>>>> Initial >>>>>\n\e\\[0m\n$" ] |> IO.iodata_to_binary() |> Regex.compile!()