Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(EXLA): Add support for MLIR compilation #1247

Merged
merged 7 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
# Build flags
CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -O3 -Wall -Wno-sign-compare \
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
-shared -std=c++17 -w -DLLVM_ON_UNIX=1
-shared -std=c++20 -w -DLLVM_ON_UNIX=1 -DLLVM_VERSION_STRING=

LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension

ifeq ($(shell uname -s), Darwin)
LDFLAGS += -flat_namespace -undefined suppress
POST_INSTALL = install_name_tool \
-change bazel-out/darwin_arm64-opt/bin/tensorflow/compiler/xla/extension/libxla_extension.so @loader_path/xla_extension/lib/libxla_extension.so \
-change bazel-out/darwin-opt/bin/tensorflow/compiler/xla/extension/libxla_extension.so @loader_path/xla_extension/lib/libxla_extension.so \
-change bazel-out/darwin_arm64-opt/bin/xla/extension/libxla_extension.so @loader_path/xla_extension/lib/libxla_extension.so \
-change bazel-out/darwin-opt/bin/xla/extension/libxla_extension.so @loader_path/xla_extension/lib/libxla_extension.so \
$(EXLA_CACHE_SO)
else
# Use a relative RPATH, so at runtime libexla.so looks for libxla_extension.so
Expand Down
76 changes: 59 additions & 17 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
#include "exla_nif_util.h"
#include "exla_client.h"
#include "exla_log_sink.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/client/client.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/lu_decomposition.h"
#include "xla/client/lib/qr.h"
#include "xla/client/lib/self_adjoint_eig.h"
#include "xla/client/lib/svd.h"
#include "xla/client/lib/sorting.h"
#include "xla/primitive_util.h"
#include "xla/pjrt/pjrt_api.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
Expand Down Expand Up @@ -2064,7 +2065,7 @@ ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg
xla::Status transfer_status = (*client)->TransferToInfeed(env, terms[0], *shape, device_id);

if(!transfer_status.ok()) {
return exla::nif::error(env, transfer_status.error_message().c_str());
return exla::nif::error(env, transfer_status.message().data());
}

data = tail;
Expand Down Expand Up @@ -2107,7 +2108,7 @@ ERL_NIF_TERM transfer_from_outfeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM

if (!statusor.ok()) {
enif_clear_env(penv);
return exla::nif::error(env, statusor.status().error_message().c_str());
return exla::nif::error(env, statusor.status().message().data());
}

ERL_NIF_TERM msg = std::move(statusor.value());
Expand Down Expand Up @@ -2191,6 +2192,45 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

// NIF wrapper around exla::GetCApiClient/1.
//
// argv[0] - device type string (e.g. "METAL"); a PJRT plugin for that
// device type must have been registered beforehand via load_pjrt_plugin.
//
// Returns {:ok, client_ref} on success or {:error, message} on failure.
ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 1) return exla::nif::error(env, "Bad argument count.");

  std::string device_type;
  if (!exla::nif::get(env, argv[0], device_type)) {
    return exla::nif::error(env, "Unable to get device type.");
  }

  EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient* client,
                            exla::GetCApiClient(device_type), env);

  ERL_NIF_TERM client_term = exla::nif::make<exla::ExlaClient*>(env, client);
  return exla::nif::ok(env, client_term);
}

// NIF: loads a PJRT plugin shared library and registers it under the given
// device type, so a later call to get_c_api_client/1 can create a client
// backed by that plugin.
//
// argv[0] - device type string (e.g. "METAL")
// argv[1] - path to the plugin's shared library
//
// Returns :ok on success or {:error, message} on failure.
ERL_NIF_TERM load_pjrt_plugin(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 2) {
    return exla::nif::error(env, "Bad argument count.");
  }

  std::string device_type;
  std::string library_path;
  if (!exla::nif::get(env, argv[0], device_type)) {
    return exla::nif::error(env, "Unable to get device type.");
  }
  if (!exla::nif::get(env, argv[1], library_path)) {
    return exla::nif::error(env, "Unable to get library path.");
  }

  xla::Status result = pjrt::LoadPjrtPlugin(device_type, library_path);

  if (!result.ok()) {
    // absl::Status::message() returns a string_view that is NOT guaranteed
    // to be NUL-terminated, so copy it into a std::string before handing a
    // C string to the error helper (avoids a potential buffer overread).
    std::string message(result.message());
    return exla::nif::error(env, message.c_str());
  }

  return exla::nif::ok(env);
}

ERL_NIF_TERM get_device_count(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
Expand Down Expand Up @@ -2326,11 +2366,11 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
exla::ExlaLogSink* sink = new exla::ExlaLogSink(logger_pid);

// NO_DEFAULT_LOGGER doesn't behave right
for (auto *log_sink : tensorflow::TFGetLogSinks()) {
tensorflow::TFRemoveLogSink(log_sink);
for (auto *log_sink : tsl::TFGetLogSinks()) {
tsl::TFRemoveLogSink(log_sink);
}

tensorflow::TFAddLogSink(sink);
tsl::TFAddLogSink(sink);

return exla::nif::ok(env);
}
Expand All @@ -2345,6 +2385,8 @@ static ErlNifFunc exla_funcs[] = {
{"get_host_client", 0, get_host_client},
{"get_gpu_client", 2, get_gpu_client},
{"get_tpu_client", 0, get_tpu_client},
{"get_c_api_client", 1, get_c_api_client},
{"load_pjrt_plugin", 2, load_pjrt_plugin},
{"get_device_count", 1, get_device_count},
{"get_supported_platforms", 0, get_supported_platforms},
{"compile", 7, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
Expand Down
18 changes: 13 additions & 5 deletions exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include "exla_client.h"
#include "exla_nif_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "xla/layout_util.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/tpu_client.h"

namespace exla {

Expand Down Expand Up @@ -458,4 +459,11 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {

return new ExlaClient(std::move(client));
}

// Creates an ExlaClient backed by the PJRT C API for `device_type`.
// A PJRT plugin for that device type must already be registered
// (see pjrt::LoadPjrtPlugin); otherwise xla::GetCApiClient fails and
// the error status is propagated to the caller via the macro.
// NOTE(review): ownership of the returned raw pointer is transferred to
// the Erlang VM as a resource elsewhere — confirm against the NIF layer.
xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
  EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
                        xla::GetCApiClient(device_type));

  return new ExlaClient(std::move(client));
}
} // namespace exla
10 changes: 6 additions & 4 deletions exla/c_src/exla/exla_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include <utility>

#include "exla_nif_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tsl/platform/types.h"
#include "tsl/platform/status.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/pjrt_client.h"

// The implementations in this module are designed after implementations
// in the XLA runtime, PjRt. Deviations are made where it makes sense
Expand Down Expand Up @@ -95,6 +95,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
xla::GpuAllocatorConfig::Kind kind);

xla::StatusOr<ExlaClient*> GetTpuClient();

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
} // namespace exla

#endif
6 changes: 3 additions & 3 deletions exla/c_src/exla/exla_log_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
#include <string>

#include "exla_nif_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tsl/platform/logging.h"
#include "absl/base/log_severity.h"

namespace exla {

// Redirects calls to logging to the Elixir Logger. `sink_pid`
// is the PID for a GenServer in Elixir which receives messages
// with logging information on every call to `LOG(severity)`.
class ExlaLogSink : public tensorflow::TFLogSink {
class ExlaLogSink : public tsl::TFLogSink {
public:
explicit ExlaLogSink(ErlNifPid sink_pid) : sink_pid_(sink_pid) {
// Logger Env
Expand Down Expand Up @@ -45,7 +45,7 @@ class ExlaLogSink : public tensorflow::TFLogSink {
return enif_make_tuple4(env_, status, msg, file, line_term);
}

void Send(const tensorflow::TFLogEntry& entry) {
void Send(const tsl::TFLogEntry& entry) {
ERL_NIF_TERM msg;
std::string msg_str = entry.ToString();
std::string fname = entry.FName();
Expand Down
4 changes: 2 additions & 2 deletions exla/c_src/exla/exla_nif_util.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "exla_nif_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "xla/shape_util.h"
#include "xla/primitive_util.h"

namespace exla {
namespace nif {
Expand Down
10 changes: 5 additions & 5 deletions exla/c_src/exla/exla_nif_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <map>

#include "erl_nif.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/shape.h"
#include "xla/xla_data.pb.h"
#include "xla/types.h"
#include "xla/shape.h"

#if !defined(__GNUC__) && (defined(__WIN32__) || defined(_WIN32) || defined(_WIN32_))
typedef unsigned __int64 nif_uint64_t;
Expand Down Expand Up @@ -346,7 +346,7 @@ ERL_NIF_TERM make_shape_info(ErlNifEnv* env, xla::Shape shape);
// Evaluates `rexpr` (an expression yielding an xla::Status) and, if the
// status is not OK, returns an Erlang error tuple from the enclosing NIF.
// Used to consume Status values whose payload is not needed.
//
// NOTE: absl::Status::message() returns a string_view that is NOT
// guaranteed to be NUL-terminated, so it is copied into a std::string
// before a C string is handed to the error helper (the temporary lives
// until the end of the full return expression, which is long enough for
// the helper to copy it into an Erlang term).
#define EXLA_EFFECT_OR_RETURN_NIF_IMPL(status, rexpr, env) \
  auto status = (rexpr); \
  if (!status.ok()) { \
    return exla::nif::error(env, std::string(status.message()).c_str()); \
  }

// Macro to be used to consume Status from within a NIF.
Expand All @@ -372,7 +372,7 @@ ERL_NIF_TERM make_shape_info(ErlNifEnv* env, xla::Shape shape);
// Evaluates `rexpr` (an expression yielding an xla::StatusOr<T>); on
// success moves the contained value into `lhs`, on failure returns an
// Erlang error tuple from the enclosing NIF.
//
// NOTE: absl::Status::message() returns a string_view that is NOT
// guaranteed to be NUL-terminated, so it is copied into a std::string
// before a C string is handed to the error helper.
#define EXLA_ASSIGN_OR_RETURN_NIF_IMPL(statusor, lhs, rexpr, env) \
  auto statusor = (rexpr); \
  if (!statusor.ok()) { \
    return exla::nif::error(env, std::string(statusor.status().message()).c_str()); \
  } \
  lhs = std::move(statusor.value());

Expand Down
1 change: 1 addition & 0 deletions exla/config/runtime.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Config
config :exla, :clients,
cuda: [platform: :cuda, memory_fraction: 0.8],
rocm: [platform: :rocm, memory_fraction: 0.8],
metal: [platform: :metal],
other_host: [platform: :host]

config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host"))
Expand Down
5 changes: 5 additions & 0 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ defmodule EXLA.Client do
:tpu ->
EXLA.NIF.get_tpu_client()

:metal ->
# TODO: Is this really where/how we want to do this?
:ok = EXLA.NIF.load_pjrt_plugin("METAL", "/opt/homebrew/lib/python3.10/site-packages/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this should be linked in the respective release for elixir-nx/xla directly, and then you can have something like this coming from there:

nx/torchx/c_src/torchx.cpp

Lines 347 to 355 in 219c23e

NIF(mps_is_available)
{
#ifdef MAC_ARM64
bool has_mps = at::hasMPS();
#else
bool has_mps = false;
#endif
return nx::nif::make(env, has_mps);
}

Copy link
Contributor

@polvalente polvalente Jun 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Torchx, this flag is defined here:

nx/torchx/CMakeLists.txt

Lines 46 to 48 in 219c23e

if(ARM64_SUPPORTED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMAC_ARM64")
endif()

I assume you'd do something similar for :xla

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to keep doing it on Elixir land: we can load the plugin on EXLA.Application. I also agree that ideally we would bundle the plugin with the precompiled XLA binary for macOS.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's going to work here because the only way to register the plugin as far as I can tell is dynamically through PjRt

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that we ship the plugin with XLA and in here we point to XLA priv dir to load it. Would that work?

EXLA.NIF.get_c_api_client("METAL")

_ ->
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
end
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,8 @@ defmodule EXLA.NIF do

# Stub replaced by the native implementation when the NIF library loads;
# raises :undef if the NIF failed to load.
def start_log_sink(_sink_pid),
  do: :erlang.nif_error(:undef)

# Stub replaced by the native implementation when the NIF library loads;
# creates a PJRT C-API-backed client for the given device type.
def get_c_api_client(_device_type), do: :erlang.nif_error(:undef)

# Stub replaced by the native implementation when the NIF library loads;
# loads and registers a PJRT plugin shared library for a device type.
def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef)
end