Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add EXLA.to_mlir_module/2 #1497

Merged
merged 7 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,43 @@ defmodule EXLA do
Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA))
end

@doc """
Takes in a function, the templates variables and the compilation options
and returns the `EXLA.Executable` struct.

## Examples

iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> module = EXLA.to_mlir_module(fun, args)
iex> EXLA.MLIR.Module.as_string(module)
~c\"\"\"
polvalente marked this conversation as resolved.
Show resolved Hide resolved
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.sine %arg0 : tensor<f32>
%1 = stablehlo.cosine %arg1 : tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
stablehlo.return %2 : tensor<f32>
}
}
\"\"\"
"""
def to_mlir_module(function, args, options \\ []) do
comp_fun = fn _key, callback ->
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
throw({:mlir_module, executable.ref})
Comment on lines +380 to +381
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gah, this is relying on internals of another module but I can't think of anything better for now, so ship it.

end

function
|> jit([
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
{:module_compilation, :to_mlir} | options
])
|> apply(args)
catch
{:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref}
end

@doc """
Checks if the compilation of function with args is cached.

Expand Down
31 changes: 21 additions & 10 deletions exla/lib/exla/mlir/module.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ defmodule EXLA.MLIR.Module do
* `:use_spmd` - enables Single-Program Multiple-Data partioning.
This is set to true if `:num_partitions` is more than one, otherwise is `false`.

* `:module_compilation` - either `:to_mlir` or `:to_pjrt`. The default is `:to_pjrt`.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josevalim not sure about the option naming here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is private, so it is fine!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we return %EXLA.MLIR.Module{} in the other function, shouldn't the whole module be public?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. We should return the string only, not the module struct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The public-facing function is now returning the binary directly


* `:to_pjrt` - the `EXLA.Executable` `:ref` field will hold the reference to a PjRt executable.
* `:to_mlir` - the `EXLA.Executable` `:ref` field will hold the reference to an MLIR module.

Currently those options do not have an effect as they related to running the
same compiled executable on multiple replicas.

Expand Down Expand Up @@ -102,16 +107,22 @@ defmodule EXLA.MLIR.Module do
# module |> as_string() |> IO.puts()

ref =
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
case Keyword.get(options, :module_compilation, :to_pjrt) do
:to_mlir ->
module.ref

:to_pjrt ->
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
end

%Executable{
client: client,
Expand Down
Loading