feat: finish MLIR scaffolding for EXLA.Defn #1293
@@ -3,25 +3,75 @@ defmodule EXLA.Builder do
   Wrapper around XLA's builder.
   """

   alias __MODULE__
-  alias EXLA.{Computation, Op}
+  alias EXLA.Computation
+  alias EXLA.Op
+  alias EXLA.MLIR.Module, as: M
+  alias EXLA.MLIR.Type

   @enforce_keys [:ref]
   defstruct [:ref, :parent, :name]

+  def new(name, _inputs, _outputs, :xla) do
+    new(name)
+  end
+
+  def new(name, inputs, outputs, :mlir) do
+    mlir_arg_types = mlir_type(Enum.map(inputs, &elem(&1, 1)))
+    mlir_ret_type = mlir_type(outputs)
+
+    xla_ret_shape = exla_shape(outputs)
+
+    module = M.new()
+    M.create_function(module, "main", mlir_arg_types, mlir_ret_type, xla_ret_shape)
+  end
+
+  defp mlir_type(input) when is_tuple(input) do
+    input
+    |> Tuple.to_list()
+    |> Enum.map(&mlir_type/1)
+  end
+
+  defp mlir_type(inputs) when is_list(inputs) do
+    Enum.map(inputs, &mlir_type/1)
+  end
+
+  defp mlir_type(%EXLA.Shape{} = shape) do
+    Type.new(shape)
+  end
+
+  defp mlir_type(%Nx.Tensor{} = t) do
+    t.type
+    |> EXLA.Shape.make_shape(t.shape)
+    |> Type.new()
+  end
+
+  defp exla_shape(%Nx.Tensor{} = t) do
+    EXLA.Shape.make_shape(t.type, t.shape)
+  end
+
   def new(name) when is_binary(name) do
     {:ok, ref} = EXLA.NIF.new_builder(name)
-    %Builder{ref: ref, parent: nil, name: name}
+    %__MODULE__{ref: ref, parent: nil, name: name}
   end

-  def new(builder = %Builder{ref: ref}, name) when is_binary(name) do
+  def new(builder = %__MODULE__{ref: ref}, name) when is_binary(name) do
     {:ok, ref} = EXLA.NIF.create_sub_builder(ref, name)
-    %Builder{ref: ref, parent: builder, name: name}
+    %__MODULE__{ref: ref, parent: builder, name: name}
   end
Review thread on the sub-builder clause:

> Btw, this is used in nested scopes such as …

> Yes, I only implemented the minimal changes to get things to work. The new branch would have been untested, so I didn't bother.
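For context, a minimal sketch of the sub-builder pattern the thread refers to, using only the `new/1` and `new/2` clauses from this file; the scope name is illustrative, not from the diff:

```elixir
# Sub-builders keep a reference to their parent builder, which is how
# nested scopes (e.g. control-flow bodies) are modeled on the XLA side.
parent = EXLA.Builder.new("main")
sub = EXLA.Builder.new(parent, "nested_scope")
%EXLA.Builder{parent: ^parent, name: "nested_scope"} = sub
```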

-  def build(root = %Op{}) do
+  def build(%Op{} = root) do
     shape = EXLA.Op.get_shape(root)
     {:ok, ref} = EXLA.NIF.build(root.builder, root.ref)
     %Computation{ref: ref, output_shape: shape}
   end

+  def build(%EXLA.MLIR.Value{} = val) do
+    %EXLA.MLIR.Value{function: function, ref: root_ref} =
+      EXLA.MLIR.Value.get_tuple_element(val, 0)
Review comment on the get_tuple_element call (resolved):

> we probably want to deal with this in a different manner
+
+    %EXLA.MLIR.Function{ref: function_ref, module: %EXLA.MLIR.Module{ref: module_ref}} = function
+    :ok = EXLA.NIF.mlir_build(function_ref, root_ref)
+    # EXLA.NIF.dump_mlir_module(module_ref)
+
+    function
+  end
 end
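To make the new four-argument entry point concrete, here is a hedged usage sketch; the shapes are invented, and the `{name, shape}` input pairs are an assumption read off the `elem(&1, 1)` mapping above:

```elixir
arg_shape = EXLA.Shape.make_shape({:f, 32}, {2, 3})
out_template = Nx.template({2, 3}, {:f, 32})

# :xla falls through to the classic XLA builder; :mlir creates an MLIR
# module whose "main" function is derived from the same metadata.
xla_builder = EXLA.Builder.new("my_fun", [{"x", arg_shape}], out_template, :xla)
mlir_fun = EXLA.Builder.new("my_fun", [{"x", arg_shape}], out_template, :mlir)
```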
@@ -33,7 +33,9 @@ defmodule EXLA.Computation do
   * `:num_partitions` - the number of partitions this computation will run on.

   """
-  def compile(computation = %Computation{}, client = %Client{}, argument_shapes, options \\ []) do
+  def compile(computation, client, argument_shapes, options \\ [])
+
+  def compile(computation = %Computation{}, client = %Client{}, argument_shapes, options) do
     num_replicas = Keyword.get(options, :num_replicas, 1)
     num_partitions = Keyword.get(options, :num_partitions, 1)

@@ -70,6 +72,21 @@ defmodule EXLA.Computation do
     }
   end

+  def compile(
+        %EXLA.MLIR.Function{module: module, xla_return_shape: ret_shape},
Review thread on the xla_return_shape field (resolved):

> Can we just call it `return_shape`?

> now that we're removing MLIR-specific things like MLIR.Type, yes

> I removed MLIR.Type, and this is now just `return_shape` :)
+        client,
+        arg_shapes,
+        _opts
+      ) do
+    EXLA.MLIR.Module.compile(
+      module,
+      client,
+      arg_shapes,
+      # TO-DO (mlir): do not hardcode this single-item tuple output type
+      EXLA.Shape.make_tuple_shape([ret_shape])
Review thread on the hardcoded tuple shape (resolved):

> @josevalim we need to think about how to deal with this. EXLA wraps the result in a tuple automatically, so I just hardcoded this output value.

> Is there an issue in always making it a tuple? A single-element tuple has no overhead afaik, and it helped address bugs between CPU and GPU.

> Not exactly an issue. The problem right now is that ret_shape was taken from a tensor return. I assume that returning a tuple or another container would need something else to be handled properly here.
+    )
+  end
+
   defp assert_output_shape!(%{output_shape: output_shape}) do
     if root_tuple_only?(output_shape) do
       output_shape
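For illustration, the single-item tuple wrapping discussed above, expressed with the `EXLA.Shape` helpers already used in this diff; the scalar `{:f, 32}` shape is just an example:

```elixir
# EXLA wraps a computation's result in a tuple automatically, so a lone
# f32 return shape is packed into a one-element tuple shape for compile.
ret_shape = EXLA.Shape.make_shape({:f, 32}, {})
tuple_shape = EXLA.Shape.make_tuple_shape([ret_shape])
```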
@@ -414,7 +414,8 @@ defmodule EXLA.Defn do
     end)

     inputs_and_shapes = Enum.reverse(reverse_inputs_and_shapes)
-    builder = EXLA.Builder.new(inspect(key))
+    mode = options[:compiler_mode] || :xla
+    builder = EXLA.Builder.new(inspect(key), inputs_and_shapes, outputs, mode)

     outfeed =
       outfeed

@@ -844,6 +845,11 @@ defmodule EXLA.Defn do
   @bin_op [:add, :subtract, :multiply, :min, :max, :remainder, :pow, :divide, :atan2] ++
             [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift]

+  defp to_operator(op, [%EXLA.MLIR.Value{} = left, %EXLA.MLIR.Value{} = right], out, state)
+       when op in @bin_op do
+    apply(EXLA.MLIR.Value, op, [left, right])
+  end
Review thread on the MLIR binary-op clause:

> In theory, we just need to add the MLIR NIFs for all the …

> We should leave it to after this and sm-metal are merged to main.
|
||
defp to_operator(op, [left, right], %{type: type}, _state) when op in @bin_op do | ||
dims = broadcast_axes(op_shape(left), op_shape(right)) | ||
apply(EXLA.Op, op, [to_type(left, type), to_type(right, type), dims]) | ||
|
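As a sketch of the follow-up work mentioned in the thread, one of those per-op MLIR bindings might look like the clause below. `EXLA.MLIR.Value.add/2` is implied by the `apply/3` dispatch above, but the `mlir_binary_op` NIF name and its signature are assumptions, not part of this PR:

```elixir
# Hypothetical sketch: inside EXLA.MLIR.Value, each @bin_op entry gets a
# clause that asks a NIF to emit the matching MLIR op into the function.
def add(%EXLA.MLIR.Value{function: func, ref: lhs}, %EXLA.MLIR.Value{ref: rhs}) do
  # `mlir_binary_op/4` is an assumed NIF name for illustration only
  {:ok, ref} = EXLA.NIF.mlir_binary_op(:add, func.ref, lhs, rhs)
  %EXLA.MLIR.Value{function: func, ref: ref}
end
```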
@@ -1,6 +1,27 @@
 defmodule EXLA.MLIR.ExecutableTest do
-  use ExUnit.Case, async: true
+  use EXLA.Case, async: true

   alias EXLA.{BinaryBuffer, DeviceBuffer, Executable}
   import EXLAHelpers

+  test "mvp" do
+    # TO-DO (mlir): this will probably be reorganized in the end
+    # This test is being added as an MVP for MLIR compilation
+
+    t1 = Nx.broadcast(0.0, {2, 3, 1})
+    t2 = Nx.broadcast(1.0, {2, 3, 1})
+
+    result =
+      EXLA.jit_apply(
+        fn t1, t2 ->
+          t1
+          |> Nx.add(t2)
+          |> then(&{Nx.add(t2, &1), Nx.subtract(t2, &1)})
+          |> then(&elem(&1, 0))
+        end,
+        [t1, t2],
+        compiler_mode: :mlir
+      )
+
+    # expected = {Nx.add(t2, t2), t1}
+    expected = Nx.add(t2, t2)
+    assert_equal(result, expected)
+  end
 end
Review thread on the builder design:

> I would move the creation of Type.new fully to the C code right now. Have the C function accept EXLA.Shape and convert to Type there, and get rid of the Type module. Otherwise it will be confusing to keep both `mlir_ret_type` and `xla_ret_shape`. We just need one type, and the easier one for now is EXLA.Shape. We can revisit later.

> Done!
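For orientation, a hedged sketch of where that lands, assuming (per the thread, not shown in this diff) that `create_function` now accepts `EXLA.Shape` structs directly and the function keeps a single `return_shape`:

```elixir
# Sketch only: the shape-to-MLIR-type conversion moves into the NIF, so
# the builder no longer needs EXLA.MLIR.Type or a separate mlir_ret_type.
def new(_name, inputs, outputs, :mlir) do
  arg_shapes = Enum.map(inputs, &elem(&1, 1))
  return_shape = EXLA.Shape.make_shape(outputs.type, outputs.shape)

  module = EXLA.MLIR.Module.new()
  EXLA.MLIR.Module.create_function(module, "main", arg_shapes, return_shape)
end
```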