Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: finish MLIR scaffolding for EXLA.Defn #1293

Merged
merged 16 commits into from
Aug 31, 2023
Merged

Conversation

polvalente
Copy link
Contributor

No description provided.

@polvalente polvalente self-assigned this Aug 30, 2023
@polvalente polvalente changed the base branch from main to sm-metal August 30, 2023 19:34
@polvalente polvalente changed the title feat: finish MLIR scaffolding for Nx.Defn feat: finish MLIR scaffolding for EXLA.Defn Aug 30, 2023
module,
client,
arg_shapes,
EXLA.Shape.make_tuple_shape([ret_shape])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josevalim we need to think about how to deal with this. EXLA wraps the result in a tuple automatically, so I just hardcoded this output value.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue in always making it a tuple? A single-element tuple has no overhead afaik and it helped address bugs between CPU and GPU.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly an issue. The problem right now is that ret_shape was taken from a tensor-return. I assume that returning a tuple or another container would need something else for actually handling here.

exla/lib/exla/builder.ex Outdated Show resolved Hide resolved
def build(%EXLA.MLIR.Value{} = val) do
%EXLA.MLIR.Value{function: function, ref: root_ref} =
# TO-DO: do not hardcode fetching just the first item as the output
EXLA.MLIR.Value.get_tuple_element(val, 0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to deal with this in a different manner

Comment on lines 848 to 851
defp to_operator(op, [%EXLA.MLIR.Value{} = left, %EXLA.MLIR.Value{} = right], out, state)
when op in @bin_op do
apply(EXLA.MLIR.Value, op, [left, right])
end
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, we just need to add the MLIR NIFs for all the @bin_ops and they'll be ready :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should leave it to after this and sm-metal are merged to main

{:ok, ref} = EXLA.NIF.create_sub_builder(ref, name)
%Builder{ref: ref, parent: builder, name: name}
%__MODULE__{ref: ref, parent: builder, name: name}
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, this is used in nested scopes such as if, cond. So I imagine we will add another branch in the future. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I only implemented the minimal changes to get things to work. The new branch would have been untested, so I didn't bother

@@ -70,6 +72,21 @@ defmodule EXLA.Computation do
}
end

def compile(
%EXLA.MLIR.Function{module: module, xla_return_shape: ret_shape},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just call it return_shape?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that we're removing MLIR-specific things like MLIR.Type, yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed MLIR.Type, and this is now just return_shape :)

end

defp mlir_type(%EXLA.Shape{} = shape) do
Type.new(shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move the creation of Type.new fully to the C code right now. Have the C function accept EXLA.Shape and convert to Type there, get rid of the Type module. Otherwise it will be confusing keeping both mlir_ret_type and xla_ret_shape. We just need one type and the easier for now is EXLA.Shape. We can revisit later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@polvalente polvalente merged commit 6c021a6 into sm-metal Aug 31, 2023
9 checks passed
@polvalente polvalente deleted the pv-mlir-scaffolding branch August 31, 2023 07:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants