-
Notifications
You must be signed in to change notification settings - Fork 193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: finish MLIR scaffolding for EXLA.Defn #1293
Conversation
exla/lib/exla/computation.ex
Outdated
module, | ||
client, | ||
arg_shapes, | ||
EXLA.Shape.make_tuple_shape([ret_shape]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josevalim we need to think about how to deal with this. EXLA wraps the result in a tuple automatically, so I just hardcoded this output value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an issue in always making it a tuple? A single-element tuple has no overhead afaik and it helped address bugs between CPU and GPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not exactly an issue. The problem right now is that ret_shape was taken from a tensor-return. I assume that returning a tuple or another container would need something else for actually handling here.
def build(%EXLA.MLIR.Value{} = val) do | ||
%EXLA.MLIR.Value{function: function, ref: root_ref} = | ||
# TO-DO: do not hardcode fetching just the first item as the output | ||
EXLA.MLIR.Value.get_tuple_element(val, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we probably want to deal with this in a different manner
exla/lib/exla/defn.ex
Outdated
defp to_operator(op, [%EXLA.MLIR.Value{} = left, %EXLA.MLIR.Value{} = right], out, state) | ||
when op in @bin_op do | ||
apply(EXLA.MLIR.Value, op, [left, right]) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In theory, we just need to add the MLIR NIFs for all the @bin_op
s and they'll be ready :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should leave it to after this and sm-metal are merged to main
{:ok, ref} = EXLA.NIF.create_sub_builder(ref, name) | ||
%Builder{ref: ref, parent: builder, name: name} | ||
%__MODULE__{ref: ref, parent: builder, name: name} | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, this is used in nested scopes such as if
, cond
. So I imagine we will add another branch in the future. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I only implemented the minimal changes to get things to work. The new branch would have been untested, so I didn't bother
exla/lib/exla/computation.ex
Outdated
@@ -70,6 +72,21 @@ defmodule EXLA.Computation do | |||
} | |||
end | |||
|
|||
def compile( | |||
%EXLA.MLIR.Function{module: module, xla_return_shape: ret_shape}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just call it return_shape
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now that we're removing MLIR-specific things like MLIR.Type, yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed MLIR.Type, and this is now just return_shape :)
exla/lib/exla/builder.ex
Outdated
end | ||
|
||
defp mlir_type(%EXLA.Shape{} = shape) do | ||
Type.new(shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move the creation of Type.new fully to the C code right now. Have the C function accept EXLA.Shape and convert to Type there, get rid of the Type module. Otherwise it will be confusing keeping both mlir_ret_type
and xla_ret_shape
. We just need one type and the easier for now is EXLA.Shape. We can revisit later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
e1e78a3
to
b34738c
Compare
8a5754f
to
152d1ee
Compare
No description provided.