Skip to content

Commit

Permalink
feat: improve container diffing (#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Jun 6, 2023
1 parent 7cdbc1b commit 219c23e
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 129 deletions.
16 changes: 1 addition & 15 deletions nx/lib/nx/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,7 @@ defmodule Nx.Defn do
raise ArgumentError, """
argument at position #{pos} is not compatible with compiled function template.
Expected template:
#{inspect(template)}
Argument template:
#{inspect(arg_template)}
Expected argument:
#{inspect(Enum.fetch!(template_args, pos - 1))}
Actual argument:
#{inspect(arg)}
#{Nx.Defn.TemplateDiff.build_and_inspect(Enum.fetch!(template_args, pos - 1), arg, "Expected", "Argument")}
"""
end

Expand Down
34 changes: 6 additions & 28 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,9 @@ defmodule Nx.Defn.Expr do
[{_, first_pred, first} | rest] ->
first = first.()

[{last_pred, last} | reverse] =
Enum.reduce(rest, [{first_pred, first}], fn {meta, pred, expr}, acc ->
{[{last_pred, last} | reverse], _} =
Enum.reduce(rest, {[{first_pred, first}], 1}, fn {meta, pred, expr},
{acc, branch_idx} ->
expr = expr.()

if not Nx.Defn.Composite.compatible?(first, expr, fn _, _ -> true end) do
Expand All @@ -440,17 +441,11 @@ defmodule Nx.Defn.Expr do
description: """
cond/if expects all branches to return compatible tensor types.
Got mismatching templates:
#{inspect_as_template(first)}
and
#{inspect_as_template(expr)}
#{Nx.Defn.TemplateDiff.build_and_inspect(first, expr, "First Branch (expected)", "Branch #{branch_idx}", fn _, _ -> true end)}
"""
end

[{pred, expr} | acc]
{[{pred, expr} | acc], branch_idx + 1}
end)

case last_pred do
Expand Down Expand Up @@ -765,13 +760,7 @@ defmodule Nx.Defn.Expr do
description: """
the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
Body matches template:
#{inspect_as_template(body)}
and initial argument has template:
#{inspect_as_template(initial)}
#{Nx.Defn.TemplateDiff.build_and_inspect(body, initial, "Body (do-block)", "Initial")}
"""
end
end
Expand Down Expand Up @@ -1470,17 +1459,6 @@ defmodule Nx.Defn.Expr do
context || acc
end

defp inspect_as_template(data) do
if is_number(data) or is_tuple(data) or
(is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do
data
|> Nx.to_template()
|> Kernel.inspect(custom_options: [skip_template_backend_header: true])
else
inspect(data)
end
end

## Constant helpers and related optimizations

defp constant(%{shape: shape, type: type} = out, number) do
Expand Down
160 changes: 160 additions & 0 deletions nx/lib/nx/defn/template_diff.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
defmodule Nx.Defn.TemplateDiff do
@moduledoc false
defstruct [:left, :right, :left_title, :right_title, :compatible]

defp is_valid_container?(impl) do
not is_nil(impl) and impl != Nx.Container.Any
end

def build(left, right, left_title, right_title, compatibility_fn \\ &Nx.compatible?/2) do
left_impl = Nx.Container.impl_for(left)
right_impl = Nx.Container.impl_for(right)

l = is_valid_container?(left_impl)
r = is_valid_container?(right_impl)

cond do
not l and not r ->
%__MODULE__{
left: left,
left_title: left_title,
right: right,
right_title: right_title,
compatible: left == right
}

not l or not r ->
%__MODULE__{
left: left,
left_title: left_title,
right: right,
right_title: right_title,
compatible: false
}

left_impl != right_impl ->
%__MODULE__{
left: left,
left_title: left_title,
right: right,
right_title: right_title,
compatible: false
}

l and r ->
{diff, acc} =
Nx.Defn.Composite.traverse(left, Nx.Defn.Composite.flatten_list([right]), fn
left, [] ->
{%__MODULE__{left: left}, :incompatible_sizes}

left, [right | acc] ->
{
%__MODULE__{
left: left,
right: right,
left_title: left_title,
right_title: right_title,
compatible: compatibility_fn.(left, right)
},
acc
}
end)

if acc == :incompatible_sizes do
%__MODULE__{
left: left,
left_title: left_title,
right: right,
right_title: right_title,
compatible: false
}
else
diff
end
end
end

def build_and_inspect(
left,
right,
left_title,
right_title,
compatibility_fn \\ &Nx.compatible?/2
) do
left
|> build(right, left_title, right_title, compatibility_fn)
|> inspect()
end

defimpl Inspect do
import Inspect.Algebra

def inspect(%Nx.Defn.TemplateDiff{left: left, right: nil}, opts) do
inspect_as_template(left, opts)
end

def inspect(%Nx.Defn.TemplateDiff{left: left, compatible: true}, opts) do
inspect_as_template(left, opts)
end

def inspect(
%Nx.Defn.TemplateDiff{
left: left,
left_title: left_title,
right: right,
right_title: right_title
},
opts
) do
{left_title, right_title} = centralize_titles(left_title, right_title)

concat([
IO.ANSI.green(),
line(),
"<<<<< #{left_title} <<<<<",
line(),
inspect_as_template(left, opts),
line(),
"==========",
line(),
IO.ANSI.red(),
inspect_as_template(right, opts),
line(),
">>>>> #{right_title} >>>>>",
line(),
IO.ANSI.reset()
])
end

defp centralize_titles(l, r) do
l_len = String.length(l)
r_len = String.length(r)
max_len = max(l_len, r_len)

{centralize_string(l, l_len, max_len), centralize_string(r, r_len, max_len)}
end

defp centralize_string(s, n, n), do: s

defp centralize_string(s, l, n) do
pad = div(n - l, 2)

s
|> String.pad_leading(l + pad)
|> String.pad_trailing(n)
end

defp inspect_as_template(data, opts) do
if is_number(data) or is_tuple(data) or
(is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do
data
|> Nx.to_template()
|> to_doc(
update_in(opts.custom_options, &Keyword.put(&1, :skip_template_backend_header, true))
)
else
to_doc(data, opts)
end
end
end
end
16 changes: 7 additions & 9 deletions nx/test/nx/container_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ defmodule Nx.ContainerTest do
var.a + var.b
end

deftransformp assert_fields!(%C{c: %{}, d: :keep}), do: 1
deftransformp(assert_fields!(%C{c: %{}, d: :keep}), do: 1)

test "keeps fields" do
inp = %Container{a: 1, b: 2, c: :reset, d: :keep}
Expand All @@ -158,7 +158,7 @@ defmodule Nx.ContainerTest do
var.a + var.b
end

deftransformp dot_assert_fields_transform(%C{c: %{}, d: %{}}), do: 1
deftransformp(dot_assert_fields_transform(%C{c: %{}, d: %{}}), do: 1)

test "keeps empty maps" do
inp = %Container{a: 1, b: 2, c: :reset, d: %{}}
Expand Down Expand Up @@ -194,10 +194,8 @@ defmodule Nx.ContainerTest do
expected_error =
[
"the do-block in while must return tensors with the same shape, type, and names as the initial arguments.",
"\n\nBody matches template:\n\n{#Nx.Tensor<\n s64\n >, ",
"%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}, #Nx.Tensor<\n s16\n >}",
"\n\nand initial argument has template:\n\n{#Nx.Tensor<\n s64\n >, ",
"%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}, #Nx.Tensor<\n u8\n >}\n$"
"\n\n{#Nx.Tensor<\n s64\n >, %Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >,",
" c: %{}, d: %{}}, \e\\[32m\n <<<<< Body \\(do-block\\) <<<<<\n #Nx.Tensor<\n s16\n >\n ==========\n \e\\[31m#Nx.Tensor<\n u8\n >\n >>>>> Initial >>>>>\n \e\\[0m}\n$"
]
|> IO.iodata_to_binary()
|> Regex.compile!()
Expand All @@ -211,9 +209,9 @@ defmodule Nx.ContainerTest do
expected_error =
[
"cond/if expects all branches to return compatible tensor types.",
"\n\nGot mismatching templates:\n\n%Container{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}",
"\n\nand\n\n#Nx.Tensor<\n s64\n>\n",
"$"
"\n\n\e\\[32m\n<<<<< First Branch \\(expected\\) <<<<<\n%Container",
"{a: #Nx.Tensor<\n s64\n >, b: #Nx.Tensor<\n s64\n >, c: %{}, d: %{}}",
"\n==========\n\e\\[31m#Nx.Tensor<\n s64\n>\n>>>>> Branch 1 >>>>>\n\e\\[0m\n$"
]
|> IO.iodata_to_binary()
|> Regex.compile!()
Expand Down
23 changes: 9 additions & 14 deletions nx/test/nx/defn/evaluator_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -679,24 +679,19 @@ defmodule Nx.Defn.EvaluatorTest do
message = """
test/nx/defn/evaluator_test.exs:646: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
Body matches template:
{#Nx.Tensor<
{\e[32m
<<<<< Body (do-block) <<<<<
#Nx.Tensor<
vectorized[a: 2]
s64[2][3]
>, #Nx.Tensor<
vectorized[a: 2]
s64[2][3]
>, #Nx.Tensor<
s64
>}
and initial argument has template:
{#Nx.Tensor<
>
==========
\e[31m#Nx.Tensor<
vectorized[a: 1]
s64[2][3]
>, #Nx.Tensor<
>
>>>>> Initial >>>>>
\e[0m, #Nx.Tensor<
vectorized[a: 2]
s64[2][3]
>, #Nx.Tensor<
Expand Down
Loading

0 comments on commit 219c23e

Please sign in to comment.