Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(exla): Host IPC and revamped pointer representation #1531

Merged
merged 7 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
fi

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


Expand Down
120 changes: 88 additions & 32 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <sstream>
#include <string>

#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_mlir.h"
#include "exla_nif_util.h"
#include "ipc.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
Expand Down Expand Up @@ -416,34 +418,60 @@ ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_T
return exla::nif::error(env, "Unable to get device pointer kind.");
}

EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);

EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr,
(*buffer)->GetDevicePointer((*client)->client()), env);

std::vector<unsigned char> pointer_vec;
ERL_NIF_TERM out_term;
if (pointer_kind == "local") {
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
pointer_vec.push_back(bytePtr[i]);
ERL_NIF_TERM ptr_term = enif_make_ulong(env, ptr);
ERL_NIF_TERM size_term = enif_make_ulong(env, device_size);
out_term = enif_make_tuple2(env, ptr_term, size_term);
} else if (pointer_kind == "host_ipc") {
std::ostringstream handle_name_stream;
handle_name_stream << "exla:ipc:" << device_size << ":" << ptr;
std::string handle_name = handle_name_stream.str();
int fd = get_ipc_handle((char*)handle_name.c_str(), device_size);

if (fd == -1) {
return exla::nif::error(env, "Unable to get IPC handle");
}

void* ipc_ptr = open_ipc_handle(fd, device_size);
if (ipc_ptr == nullptr) {
return exla::nif::error(env, "Unable to open IPC handle");
}

memcpy(ipc_ptr, (void*)ptr, device_size);

ErlNifBinary handle_name_bin;
enif_alloc_binary(handle_name.size(), &handle_name_bin);
for (int i = 0; i < handle_name.size(); i++) {
handle_name_bin.data[i] = handle_name[i];
}
ERL_NIF_TERM handle_name_term = enif_make_binary(env, &handle_name_bin);
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
ERL_NIF_TERM fd_term = enif_make_int(env, fd);
out_term = enif_make_tuple3(env, handle_name_term, fd_term, size_term);
} else if (pointer_kind == "cuda_ipc") {
auto result = get_cuda_ipc_handle(ptr);
if (result.second) {
return exla::nif::error(env, "Unable to get cuda IPC handle");
}
pointer_vec = result.first;
}
auto pointer_vec = result.first;

EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);

ERL_NIF_TERM handle_list[pointer_vec.size()];
for (int i = 0; i < pointer_vec.size(); i++) {
handle_list[i] = enif_make_uint(env, pointer_vec[i]);
ErlNifBinary handle_bin;
enif_alloc_binary(pointer_vec.size(), &handle_bin);
for (int i = 0; i < pointer_vec.size(); i++) {
handle_bin.data[i] = pointer_vec[i];
}
ERL_NIF_TERM handle_term = enif_make_binary(env, &handle_bin);
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
out_term = enif_make_tuple2(env, handle_term, size_term);
}

ERL_NIF_TERM handle_list_term = enif_make_list_from_array(env, handle_list, pointer_vec.size());
ERL_NIF_TERM device_size_term = enif_make_uint64(env, device_size);

return exla::nif::ok(env, enif_make_tuple2(env, handle_list_term, device_size_term));
return exla::nif::ok(env, out_term);
}

ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
Expand All @@ -452,40 +480,68 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
}

exla::ExlaClient** client;
std::vector<int64_t> pointer_vec;
ErlNifBinary cuda_ipc_handle_bin;
int cuda_ipc_handle_size = 0;
xla::Shape shape;
int device_id;
std::string pointer_kind;
void* ptr;
int fd = -1;
std::string memname;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_list(env, argv[1], pointer_vec)) {
return exla::nif::error(env, "Unable to get device pointer.");
}
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
if (!exla::nif::get_atom(env, argv[1], pointer_kind)) {
return exla::nif::error(env, "Unable to get device pointer kind.");
}

if (pointer_kind == "cuda_ipc") {
if (!enif_inspect_binary(env, argv[2], &cuda_ipc_handle_bin)) {
return exla::nif::error(env, "Unable to get CUDA IPC handle.");
}
} else if (pointer_kind == "host_ipc") {
const ERL_NIF_TERM* tuple;
int arity;
if (
!enif_get_tuple(env, argv[2], &arity, &tuple) ||
(arity != 2) ||
!exla::nif::get(env, tuple[0], &fd) ||
(fd == -1) ||
!exla::nif::get(env, tuple[1], memname)) {
return exla::nif::error(env, "Unable to get IPC handle.");
}
} else if (pointer_kind == "local") {
int64_t ptr_int;
if (!exla::nif::get(env, argv[2], &ptr_int)) {
return exla::nif::error(env, "Unable to get pointer.");
}

ptr = (void*)ptr_int;
}

if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
if (!exla::nif::get(env, argv[4], &device_id)) {
return exla::nif::error(env, "Unable to get device ordinal.");
}

void* ptr;
if (pointer_kind == "local") {
if (pointer_vec.size() != sizeof(void*)) {
// This helps prevent segfaults if someone passes an IPC handle instead of
// a local pointer.
return exla::nif::error(env, "Invalid pointer size for selected mode.");
}
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
bytePtr[i] = pointer_vec[i];
std::function<void()> on_delete_callback = []() {};

if (pointer_kind == "host_ipc") {
size_t device_size = (size_t)xla::ShapeUtil::ByteSizeOf(shape);

ptr = open_ipc_handle(fd, device_size);
if (ptr == nullptr) {
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
}

on_delete_callback = [fd, memname, ptr, device_size]() {
close_ipc_handle(fd, ptr, (char*)memname.c_str(), device_size);
};
} else if (pointer_kind == "cuda_ipc") {
auto result = get_pointer_for_ipc_handle(pointer_vec, device_id);
auto result = get_pointer_for_ipc_handle(cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id);
if (result.second) {
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
}
Expand All @@ -494,8 +550,8 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E

EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);

std::function<void()> on_delete_callback = []() {};
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);

exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
}
Expand Down
9 changes: 4 additions & 5 deletions exla/c_src/exla/exla_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
return std::make_pair(result, status != cudaSuccess);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
if (handle_list.size() != sizeof(cudaIpcMemHandle_t)) {
printf("Error: Invalid CUDA IPC memory handle size\n");
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
if (handle_size != sizeof(cudaIpcMemHandle_t)) {
return std::make_pair(nullptr, 1); // Return with error status
}

unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
ipc_handle_data[i] = (uint8_t)handle_list[i];
ipc_handle_data[i] = handle_bin[i];
}

cudaIpcMemHandle_t ipc_handle;
Expand All @@ -54,7 +53,7 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
return std::make_pair(std::vector<unsigned char>(0), 1);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
return std::make_pair(nullptr, 1);
}
#endif
2 changes: 1 addition & 1 deletion exla/c_src/exla/exla_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
#include <vector>

std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>, int);
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t*, size_t, int);
47 changes: 47 additions & 0 deletions exla/c_src/exla/ipc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "ipc.h"

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <iostream>

// Function to create or open a shared memory object and set its size
int get_ipc_handle(const char* memname, size_t memsize) {
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);
if (fd == -1) {
return -1;
}

if (ftruncate(fd, memsize) == -1) {
close(fd);
return -1;
}

return fd;
}

// Function to map the shared memory in this process
void* open_ipc_handle(int fd, size_t memsize) {
void* ptr = mmap(NULL, memsize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if (ptr == MAP_FAILED) {
perror("mmap");
return nullptr;
}
return ptr;
}

int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize) {
if (munmap(ptr, memsize) == -1) {
return -1;
}

if (close(fd) == -1) {
return -1;
}

shm_unlink(memname);

return 0;
}
7 changes: 7 additions & 0 deletions exla/c_src/exla/ipc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

#include <cstddef>

int get_ipc_handle(const char* memname, size_t memsize);
void* open_ipc_handle(int fd, size_t memsize);
int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize);
63 changes: 53 additions & 10 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ defmodule EXLA.Backend do

mode =
case opts[:mode] do
mode when mode in [:local, :cuda_ipc] ->
mode when mode in [:local, :cuda_ipc, :host_ipc] ->
mode

mode ->
raise ArgumentError, "expected one of :local, :cuda_ipc, got: #{inspect(mode)}"
raise ArgumentError,
"expected one of :local, :cuda_ipc, :host_ipc, got: #{inspect(mode)}"
end

case buffer do
Expand All @@ -106,33 +107,75 @@ defmodule EXLA.Backend do
client = EXLA.Client.fetch!(buffer.client_name)

case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
{:ok, {pointer, _size}} ->
{:ok, pointer}
{:ok, result} ->
handle =
case {result, mode} do
{{ptr, _size}, :local} when is_integer(ptr) ->
# pointer is an integer here
ptr

{{handle_name, fd, size}, :host_ipc} ->
%EXLA.IPCHandle.Host{
name: handle_name,
fd: fd,
size: size
}

{{handle, size}, :cuda_ipc} ->
%EXLA.IPCHandle.CUDA{
handle: handle,
device_id: buffer.device_id,
size: size
}
end

{:ok, handle}

error ->
error
end
end

@impl true
def from_pointer(pointer, type, dims, backend_opts, opts) do
backend_opts = Keyword.validate!(backend_opts, [:client_name, :device_id])
opts = Keyword.validate!(opts, [:names, mode: :local])
def from_pointer(handle, type, dims, backend_opts, opts) do
# mode is inferred from the handle kind
backend_opts = Keyword.validate!(backend_opts, [:client, :device_id])
opts = Keyword.validate!(opts, [:names])

template = Nx.template(dims, type, names: opts[:names])

client_name = backend_opts[:client_name] || EXLA.Client.default_name()
client_name = backend_opts[:client] || EXLA.Client.default_name()
client = EXLA.Client.fetch!(client_name)

device_id = backend_opts[:device_id] || client.default_device_id

typespec = EXLA.Typespec.tensor(type, dims)

num_elements = Tuple.product(dims)
{_, bits} = type
shape_size = num_elements * div(bits, 8)

{mode, handle_nif} =
case handle do
%{size: size} when size != shape_size ->
raise ArgumentError,
"invalid IPC handle size for shape, expected: #{shape_size}, got: #{size}"

%EXLA.IPCHandle.Host{fd: fd, name: name} ->
{:host_ipc, {fd, name}}

%EXLA.IPCHandle.CUDA{handle: handle} ->
{:cuda_ipc, handle}

_ when is_integer(handle) ->
{:local, handle}
end

result =
EXLA.NIF.create_buffer_from_device_pointer(
client.ref,
pointer,
opts[:mode],
mode,
handle_nif,
EXLA.Typespec.nif_encode(typespec),
device_id
)
Expand Down
Loading
Loading