Skip to content

Commit

Permalink
refactor: support vectorized constant
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Sep 16, 2024
1 parent 20cc168 commit 22b9a24
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
5 changes: 5 additions & 0 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,11 @@ defmodule Nx.Defn.Expr do

## Constant helpers and related optimizations

defp constant(%{vectorized_axes: [_ | _]} = out, number) do
out = %{out | names: Enum.map(out.names, fn _ -> nil end)}
tensor(Nx.fill(out, number, type: out.type))
end

defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
Expand Down
39 changes: 22 additions & 17 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,10 @@ defmodule Nx.Defn.Grad do
{expr, graded}
end

defp constant(float, shape) do
case shape do
%T{vectorized_axes: [_ | _]} = t ->
Expr.tensor(Nx.fill(t, float, type: :f32))

t ->
shape = Nx.shape(t)
names = List.duplicate(nil, tuple_size(shape))
Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])
end
defp constant(float, %T{shape: shape} = t) do
names = List.duplicate(nil, tuple_size(shape))

Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
end

defp validate_expr!(%T{data: %Expr{}} = expr) do
Expand Down Expand Up @@ -1351,22 +1345,33 @@ defmodule Nx.Defn.Grad do
%T{names: names} ->
names = Enum.with_index(names, fn name, idx -> if(name, do: {name, idx}) end)

vectorized_axes =
{vectorized_axes, offset} =
names
|> Enum.reduce([], fn
|> Enum.reduce({[], 0}, fn
nil, acc ->
acc

{name, _idx}, acc ->
{name, _idx}, {acc, count} ->
if name in names_ans do
[name | acc]
{[name | acc], count + 1}
else
acc
{acc, count}
end
end)
|> Enum.reverse()

Nx.vectorize(arg, vectorized_axes)
axes_names = Enum.reverse(vectorized_axes)

{vec_shape_list, shape_list} = arg.shape |> Tuple.to_list() |> Enum.split(offset)

vectorized_axes =
Enum.zip(axes_names, vec_shape_list)

%{
arg
| vectorized_axes: vectorized_axes,
names: Enum.drop(arg.names, offset),
shape: List.to_tuple(shape_list)
}

arg ->
arg
Expand Down

0 comments on commit 22b9a24

Please sign in to comment.