
feat: finish MLIR scaffolding for EXLA.Defn #1293

Merged · 16 commits · Aug 31, 2023
46 changes: 23 additions & 23 deletions exla/c_src/exla/exla_client.cc
@@ -287,28 +287,28 @@ xla::StatusOr<ERL_NIF_TERM> ExlaExecutable::Run(ErlNifEnv* env,

std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> per_replica_results;

EXLA_ASSIGN_OR_RETURN(per_replica_results, executable_->Execute(input_buffers, options));

// if (device_id >= 0) {
// // if we specified a device id, then we need to execute the executable as a portable
// // executable, meaning we need to find the device corresponding to the specific device
// // id and execute on that device, we've already guaranteed this executable only has 1
// // replica
// EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice* device, client_->client()->LookupDevice(device_id));
// // because this is a portable executable, it only has 1 replica and so we only need
// // to get the arguments at the first position of the input buffers
// std::vector<xla::PjRtBuffer *> portable_args = input_buffers.at(0);
// EXLA_ASSIGN_OR_RETURN(auto portable_result,
// executable_->ExecutePortable(portable_args, device, options));
// // the logic for handling unpacking of results is shared between portable code path
// // and the replicated code-path, so we take ownership of the result buffers to unpack
// per_replica_results.push_back(std::move(portable_result));
// } else {
// // no device ID is present, so it may be a replicated executable which means we need
// // to use the replica execution path
// // TODO: This now exposes a `returned_futures` API, does this make sense for us?
// EXLA_ASSIGN_OR_RETURN(per_replica_results, executable_->Execute(input_buffers, options));
// }
if (device_id >= 0) {
// if we specified a device id, then we need to execute the executable as a portable
// executable, meaning we need to find the device corresponding to the specific device
// id and execute on that device, we've already guaranteed this executable only has 1
// replica
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice* device, client_->client()->LookupDevice(device_id));
// because this is a portable executable, it only has 1 replica and so we only need
// to get the arguments at the first position of the input buffers
std::vector<xla::PjRtBuffer *> portable_args = input_buffers.at(0);
EXLA_ASSIGN_OR_RETURN(auto portable_result,
executable_->ExecutePortable(portable_args, device, options));
// the logic for handling unpacking of results is shared between portable code path
// and the replicated code-path, so we take ownership of the result buffers to unpack
per_replica_results.push_back(std::move(portable_result));
} else {
// no device ID is present, so it may be a replicated executable which means we need
// to use the replica execution path
// TODO: This now exposes a `returned_futures` API, does this make sense for us?
EXLA_ASSIGN_OR_RETURN(per_replica_results, executable_->Execute(input_buffers, options));
}

// EXLA_ASSIGN_OR_RETURN(per_replica_results, executable_->Execute(input_buffers, options));

// sanity check
if (per_replica_results.size() != num_replicas) {
@@ -482,7 +482,7 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
};

EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetStreamExecutorGpuClient(false, allocator_config, nullptr, 0));
xla::GetStreamExecutorGpuClient(false, allocator_config, 0, 0));

return new ExlaClient(std::move(client));
}
1 change: 0 additions & 1 deletion exla/config/runtime.exs
@@ -3,7 +3,6 @@ import Config
config :exla, :clients,
cuda: [platform: :cuda, memory_fraction: 0.8],
rocm: [platform: :rocm, memory_fraction: 0.8],
metal: [platform: :metal],
other_host: [platform: :host]

config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host"))
62 changes: 56 additions & 6 deletions exla/lib/exla/builder.ex
@@ -3,25 +3,75 @@ defmodule EXLA.Builder do
Wrapper around XLA's builder.
"""

alias __MODULE__
alias EXLA.{Computation, Op}
alias EXLA.Computation
alias EXLA.Op
alias EXLA.MLIR.Module, as: M
alias EXLA.MLIR.Type

@enforce_keys [:ref]
defstruct [:ref, :parent, :name]

def new(name, _inputs, _outputs, :xla) do
new(name)
end

def new(name, inputs, outputs, :mlir) do
mlir_arg_types = mlir_type(Enum.map(inputs, &elem(&1, 1)))
mlir_ret_type = mlir_type(outputs)

xla_ret_shape = exla_shape(outputs)

module = M.new()
M.create_function(module, "main", mlir_arg_types, mlir_ret_type, xla_ret_shape)
end

defp mlir_type(input) when is_tuple(input) do
input
|> Tuple.to_list()
|> Enum.map(&mlir_type/1)
end

defp mlir_type(inputs) when is_list(inputs) do
Enum.map(inputs, &mlir_type/1)
end

defp mlir_type(%EXLA.Shape{} = shape) do
Type.new(shape)
Collaborator: I would move the creation of Type.new fully to the C code right now. Have the C function accept EXLA.Shape and convert to Type there, and get rid of the Type module. Otherwise it will be confusing to keep both mlir_ret_type and xla_ret_shape. We just need one type, and the easier one for now is EXLA.Shape. We can revisit later.

Contributor Author: Done!

end

defp mlir_type(%Nx.Tensor{} = t) do
t.type
|> EXLA.Shape.make_shape(t.shape)
|> Type.new()
end

defp exla_shape(%Nx.Tensor{} = t) do
EXLA.Shape.make_shape(t.type, t.shape)
end
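Following the thread above about dropping the Type module, here is a rough sketch of how exla_shape/1 might generalize to cover containers once EXLA.Shape is passed straight to the NIF. The extra clauses are an assumption for illustration, not the merged code.

  # Sketch (assumed, not the merged code): generalize exla_shape/1 to
  # tuples, lists, and ready-made shapes, so only EXLA.Shape crosses the
  # NIF boundary and the C++ side derives the MLIR types.
  defp exla_shape(tuple) when is_tuple(tuple) do
    tuple |> Tuple.to_list() |> Enum.map(&exla_shape/1)
  end

  defp exla_shape(list) when is_list(list), do: Enum.map(list, &exla_shape/1)

  defp exla_shape(%EXLA.Shape{} = shape), do: shape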

def new(name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.new_builder(name)
%Builder{ref: ref, parent: nil, name: name}
%__MODULE__{ref: ref, parent: nil, name: name}
end

def new(builder = %Builder{ref: ref}, name) when is_binary(name) do
def new(builder = %__MODULE__{ref: ref}, name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.create_sub_builder(ref, name)
%Builder{ref: ref, parent: builder, name: name}
%__MODULE__{ref: ref, parent: builder, name: name}
end
Collaborator: By the way, this is used in nested scopes such as if and cond, so I imagine we will add another branch here in the future. :)

Contributor Author: Yes, I only implemented the minimal changes needed to get things working. The new branch would have been untested, so I didn't bother.

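To make the thread above concrete, a purely hypothetical sketch of an MLIR counterpart to the sub-builder clause for nested scopes. EXLA.NIF.create_mlir_region does not exist in this PR and is only assumed here.

  # Hypothetical (not part of this PR): a nested-scope builder for MLIR,
  # analogous to create_sub_builder on the XLA path. The NIF name is assumed.
  def new(%EXLA.MLIR.Function{module: module}, name) when is_binary(name) do
    {:ok, ref} = EXLA.NIF.create_mlir_region(module.ref, name)
    %EXLA.MLIR.Function{module: module, ref: ref, name: name}
  end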

def build(root = %Op{}) do
def build(%Op{} = root) do
shape = EXLA.Op.get_shape(root)
{:ok, ref} = EXLA.NIF.build(root.builder, root.ref)
%Computation{ref: ref, output_shape: shape}
end

def build(%EXLA.MLIR.Value{} = val) do
%EXLA.MLIR.Value{function: function, ref: root_ref} =
EXLA.MLIR.Value.get_tuple_element(val, 0)
Contributor Author: We probably want to deal with this in a different manner.


%EXLA.MLIR.Function{ref: function_ref, module: %EXLA.MLIR.Module{ref: module_ref}} = function
:ok = EXLA.NIF.mlir_build(function_ref, root_ref)
# EXLA.NIF.dump_mlir_module(module_ref)
function
end
end
10 changes: 0 additions & 10 deletions exla/lib/exla/client.ex
@@ -170,16 +170,6 @@ defmodule EXLA.Client do
:tpu ->
EXLA.NIF.get_tpu_client()

:metal ->
# TODO: Is this really where/how we want to do this?
:ok =
EXLA.NIF.load_pjrt_plugin(
"METAL",
"/opt/homebrew/lib/python3.10/site-packages/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib"
)

EXLA.NIF.get_c_api_client("METAL")

_ ->
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
end
19 changes: 18 additions & 1 deletion exla/lib/exla/computation.ex
@@ -33,7 +33,9 @@ defmodule EXLA.Computation do
* `:num_partitions` - the number of partitions this computation will run on.

"""
def compile(computation = %Computation{}, client = %Client{}, argument_shapes, options \\ []) do
def compile(computation, client, argument_shapes, options \\ [])

def compile(computation = %Computation{}, client = %Client{}, argument_shapes, options) do
num_replicas = Keyword.get(options, :num_replicas, 1)
num_partitions = Keyword.get(options, :num_partitions, 1)

@@ -70,6 +72,21 @@
}
end

def compile(
%EXLA.MLIR.Function{module: module, xla_return_shape: ret_shape},
Collaborator: Can we just call it return_shape?

Contributor Author: Now that we're removing MLIR-specific things like MLIR.Type, yes.

Contributor Author: I removed MLIR.Type, and this is now just return_shape :)

client,
arg_shapes,
_opts
) do
EXLA.MLIR.Module.compile(
module,
client,
arg_shapes,
# TO-DO (mlir): do not hardcode this single-item tuple output type
EXLA.Shape.make_tuple_shape([ret_shape])
Contributor Author: @josevalim we need to think about how to deal with this. EXLA wraps the result in a tuple automatically, so I just hardcoded this output value.

Collaborator: Is there an issue with always making it a tuple? A single-element tuple has no overhead AFAIK, and it helped address bugs between CPU and GPU.

Contributor Author: Not exactly an issue. The problem right now is that ret_shape was taken from a tensor return. I assume that returning a tuple or another container would need extra handling here.
)
end
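As a hedged sketch of the direction discussed above, the hardcoded single-item tuple could later generalize roughly as follows, assuming tuple shapes carry their element shapes under dtype: {:tuple, shapes} in EXLA.Shape. This is not code from this PR.

  # Sketch (assumption): always compile to a root tuple, whether the MLIR
  # function returns a single tensor shape or an existing tuple shape.
  return_shapes =
    case ret_shape do
      %EXLA.Shape{dtype: {:tuple, inner_shapes}} -> inner_shapes
      %EXLA.Shape{} = single -> [single]
    end

  EXLA.Shape.make_tuple_shape(return_shapes)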

defp assert_output_shape!(%{output_shape: output_shape}) do
if root_tuple_only?(output_shape) do
output_shape
8 changes: 7 additions & 1 deletion exla/lib/exla/defn.ex
@@ -414,7 +414,8 @@ defmodule EXLA.Defn do
end)

inputs_and_shapes = Enum.reverse(reverse_inputs_and_shapes)
builder = EXLA.Builder.new(inspect(key))
mode = options[:compiler_mode] || :xla
builder = EXLA.Builder.new(inspect(key), inputs_and_shapes, outputs, mode)

outfeed =
outfeed
@@ -844,6 +845,11 @@
@bin_op [:add, :subtract, :multiply, :min, :max, :remainder, :pow, :divide, :atan2] ++
[:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift]

defp to_operator(op, [%EXLA.MLIR.Value{} = left, %EXLA.MLIR.Value{} = right], out, state)
when op in @bin_op do
apply(EXLA.MLIR.Value, op, [left, right])
end
Contributor Author: In theory, we just need to add the MLIR NIFs for all the @bin_op entries and they'll be ready :)

Contributor Author: We should leave that until after this and sm-metal are merged into main.

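To illustrate the comment above, one possible way the remaining binary ops could be exposed on EXLA.MLIR.Value once their NIFs exist. The mlir_* NIF names below are assumptions, not functions present in this PR.

  # Hypothetical sketch: generate one EXLA.MLIR.Value wrapper per binary op.
  # NIF names such as EXLA.NIF.mlir_add/3 are assumed, not implemented here.
  for op <- [:add, :subtract, :multiply, :min, :max, :remainder, :divide] do
    nif = :"mlir_#{op}"

    def unquote(op)(
          %Value{function: %Function{} = func, ref: lhs},
          %Value{function: func, ref: rhs}
        ) do
      ref = apply(EXLA.NIF, unquote(nif), [func.ref, lhs, rhs]) |> unwrap!()
      %Value{ref: ref, function: func}
    end
  end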

defp to_operator(op, [left, right], %{type: type}, _state) when op in @bin_op do
dims = broadcast_axes(op_shape(left), op_shape(right))
apply(EXLA.Op, op, [to_type(left, type), to_type(right, type), dims])
1 change: 0 additions & 1 deletion exla/lib/exla/lib.ex
@@ -145,7 +145,6 @@ defmodule EXLA.Lib do

defp min_binary({:pred, 8}), do: <<0>>
defp min_binary(type), do: Nx.Type.min_binary(type)

defp max_binary({:pred, 8}), do: <<1>>
defp max_binary(type), do: Nx.Type.max_binary(type)

2 changes: 1 addition & 1 deletion exla/lib/exla/mlir/function.ex
@@ -2,7 +2,7 @@ defmodule EXLA.MLIR.Function do
@moduledoc """
Representation of an MLIR Function or `func.func` type.
"""
defstruct [:module, :ref, :name]
defstruct [:module, :ref, :name, :xla_return_shape]

alias __MODULE__, as: Function
alias EXLA.MLIR.Module
16 changes: 11 additions & 5 deletions exla/lib/exla/mlir/module.ex
@@ -22,10 +22,16 @@ defmodule EXLA.MLIR.Module do
Creates a new MLIR function with the given name belonging
to the given MLIR module.
"""
def create_function(%Module{ref: module_ref} = module, name, arg_types, %Type{
dims: ret_dims,
type: ret_type
})
def create_function(
%Module{ref: module_ref} = module,
name,
arg_types,
%Type{
dims: ret_dims,
type: ret_type
} = return_type,
xla_ret_shape
)
when is_binary(name) do
nif_arg_types =
Enum.map(arg_types, fn %Type{dims: dims, type: type} ->
@@ -37,7 +43,7 @@
ref =
EXLA.NIF.create_mlir_function(module_ref, name, nif_arg_types, nif_ret_type) |> unwrap!()

%Function{module: module, ref: ref, name: name}
%Function{module: module, ref: ref, name: name, xla_return_shape: xla_ret_shape}
end

@doc """
3 changes: 2 additions & 1 deletion exla/lib/exla/mlir/value.ex
@@ -33,7 +33,8 @@ defmodule EXLA.MLIR.Value do
%Value{ref: ref, function: func}
end

def get_tuple_element(%Value{function: %Function{} = func, ref: ref}, index) when is_integer(index) do
def get_tuple_element(%Value{function: %Function{} = func, ref: ref}, index)
when is_integer(index) do
ref = EXLA.NIF.mlir_get_tuple_element(func.ref, ref, index) |> unwrap!()
%Value{ref: ref, function: func}
end
21 changes: 21 additions & 0 deletions exla/lib/exla/op.ex
@@ -65,6 +65,12 @@ defmodule EXLA.Op do
%Op{builder: builder, ref: ref}
end

def parameter(%EXLA.MLIR.Function{} = function, i, _shape, _name) do
function
|> EXLA.MLIR.Function.get_arguments()
|> Enum.fetch!(i)
end

@doc """
Builds a tuple with the given elements.
"""
@@ -74,6 +80,12 @@
%Op{builder: builder, ref: ref}
end

def tuple(%EXLA.MLIR.Function{} = function, elements) when is_list(elements) do
elements
|> Enum.map(fn %{function: ^function} = e -> e end)
|> EXLA.MLIR.Value.tuple()
end

@doc """
Creates tensor with normal distribution.
"""
@@ -226,6 +238,10 @@
%{op | ref: ref}
end

def get_tuple_element(%EXLA.MLIR.Value{} = operand, index) when is_integer(index) do
EXLA.MLIR.Value.get_tuple_element(operand, index)
end

def conditional(
%Op{builder: builder, ref: pred},
%Op{builder: builder, ref: true_op},
@@ -803,6 +819,11 @@
%Op{builder: builder, ref: ref}
end

def create_token(%EXLA.MLIR.Function{} = function) do
# TO-DO (mlir): actually do something here
function
end

def create_token(%Builder{ref: builder}) do
ref = EXLA.NIF.create_token(builder) |> unwrap!()
%Op{builder: builder, ref: ref}
2 changes: 1 addition & 1 deletion exla/mix.exs
@@ -62,7 +62,7 @@ defmodule EXLA.MixProject do
# {:nx, "~> 0.5.1"},
{:nx, path: "../nx"},
{:telemetry, "~> 0.4.0 or ~> 1.0"},
{:xla, "~> 0.4.4", runtime: false},
{:xla, "~> 0.5.0", runtime: false},
{:elixir_make, "~> 0.6", runtime: false},
{:benchee, "~> 1.0", only: :dev},
{:ex_doc, "~> 0.29.0", only: :docs}
4 changes: 2 additions & 2 deletions exla/mix.lock
@@ -3,7 +3,7 @@
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark_parser": {:hex, :earmark_parser, "1.4.31", "a93921cdc6b9b869f519213d5bc79d9e218ba768d7270d46fdcf1c01bacff9e2", [:mix], [], "hexpm", "317d367ee0335ef037a87e46c91a2269fef6306413f731e8ec11fc45a7efd059"},
"elixir_make": {:hex, :elixir_make, "0.7.4", "5439110c964ffdd8212ca919b5b8beac423085a77ad33d5e394abe812c2d2d75", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "70c33052f7b00c813fd66d15a3cf1f7d1e122860c572ec81b8181b1276074157"},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"ex_doc": {:hex, :ex_doc, "0.29.3", "f07444bcafb302db86e4f02d8bbcd82f2e881a0dcf4f3e4740e4b8128b9353f7", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3dc6787d7b08801ec3b51e9bd26be5e8826fbf1a17e92d1ebc252e1a1c75bfe1"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
@@ -12,5 +12,5 @@
"nx": {:hex, :nx, "0.5.1", "118134b8c97c2a8f86c87aa8434994c1cbbe139a306b89cca04e08dd46228067", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ceb8fbbe19b3c4252a7188d8b0e059fac9da0f4a4f3bb770fc665fdd0b29f0c5"},
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"xla": {:hex, :xla, "0.4.4", "c3a8ed1f579bda949df505e49ff65415c8281d991fbd6ae1d8f3c5d0fd155f54", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "484f3f9011db3c9f1ff1e98eecefd382f3882a07ada540fd58803db1d2dab671"},
"xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"},
}
4 changes: 2 additions & 2 deletions exla/test.exs
@@ -1,4 +1,4 @@
client = EXLA.Client.fetch!(:metal)
client = EXLA.Client.fetch!(:host)

arg_xla_shape = EXLA.Shape.make_shape({:f, 32}, {4, 3, 1})
mlir_arg_types = Enum.map([arg_xla_shape, arg_xla_shape], &EXLA.MLIR.Type.new/1)
@@ -30,4 +30,4 @@ result
|> EXLA.DeviceBuffer.read()
|> Nx.from_binary(:f32)
|> Nx.reshape({4, 3, 1})
|> IO.inspect
|> IO.inspect(label: "all negative ones")
27 changes: 24 additions & 3 deletions exla/test/exla/mlir/executable_test.exs
@@ -1,6 +1,27 @@
defmodule EXLA.MLIR.ExecutableTest do
use ExUnit.Case, async: true
use EXLA.Case, async: true

alias EXLA.{BinaryBuffer, DeviceBuffer, Executable}
import EXLAHelpers
test "mvp" do
# TO-DO (mlir): this will probably be reorganized in the end
# This test is being added as an MVP for MLIR compilation

t1 = Nx.broadcast(0.0, {2, 3, 1})
t2 = Nx.broadcast(1.0, {2, 3, 1})

result =
EXLA.jit_apply(
fn t1, t2 ->
t1
|> Nx.add(t2)
|> then(&{Nx.add(t2, &1), Nx.subtract(t2, &1)})
|> then(&elem(&1, 0))
end,
[t1, t2],
compiler_mode: :mlir
)

# expected = {Nx.add(t2, t2), t1}
expected = Nx.add(t2, t2)
assert_equal(result, expected)
end
end