Skip to content

Commit

Permalink
feat(exla): Host IPC and revamped pointer representation (#1531)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Sep 14, 2024
1 parent 116b124 commit 93e4383
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 76 deletions.
4 changes: 2 additions & 2 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
fi

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


Expand Down
120 changes: 88 additions & 32 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <sstream>
#include <string>

#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_mlir.h"
#include "exla_nif_util.h"
#include "ipc.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
Expand Down Expand Up @@ -449,34 +451,60 @@ ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_T
return exla::nif::error(env, "Unable to get device pointer kind.");
}

EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);

EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr,
(*buffer)->GetDevicePointer((*client)->client()), env);

std::vector<unsigned char> pointer_vec;
ERL_NIF_TERM out_term;
if (pointer_kind == "local") {
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
pointer_vec.push_back(bytePtr[i]);
ERL_NIF_TERM ptr_term = enif_make_ulong(env, ptr);
ERL_NIF_TERM size_term = enif_make_ulong(env, device_size);
out_term = enif_make_tuple2(env, ptr_term, size_term);
} else if (pointer_kind == "host_ipc") {
std::ostringstream handle_name_stream;
handle_name_stream << "exla:ipc:" << device_size << ":" << ptr;
std::string handle_name = handle_name_stream.str();
int fd = get_ipc_handle((char*)handle_name.c_str(), device_size);

if (fd == -1) {
return exla::nif::error(env, "Unable to get IPC handle");
}

void* ipc_ptr = open_ipc_handle(fd, device_size);
if (ipc_ptr == nullptr) {
return exla::nif::error(env, "Unable to open IPC handle");
}

memcpy(ipc_ptr, (void*)ptr, device_size);

ErlNifBinary handle_name_bin;
enif_alloc_binary(handle_name.size(), &handle_name_bin);
for (int i = 0; i < handle_name.size(); i++) {
handle_name_bin.data[i] = handle_name[i];
}
ERL_NIF_TERM handle_name_term = enif_make_binary(env, &handle_name_bin);
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
ERL_NIF_TERM fd_term = enif_make_int(env, fd);
out_term = enif_make_tuple3(env, handle_name_term, fd_term, size_term);
} else if (pointer_kind == "cuda_ipc") {
auto result = get_cuda_ipc_handle(ptr);
if (result.second) {
return exla::nif::error(env, "Unable to get cuda IPC handle");
}
pointer_vec = result.first;
}
auto pointer_vec = result.first;

EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);

ERL_NIF_TERM handle_list[pointer_vec.size()];
for (int i = 0; i < pointer_vec.size(); i++) {
handle_list[i] = enif_make_uint(env, pointer_vec[i]);
ErlNifBinary handle_bin;
enif_alloc_binary(pointer_vec.size(), &handle_bin);
for (int i = 0; i < pointer_vec.size(); i++) {
handle_bin.data[i] = pointer_vec[i];
}
ERL_NIF_TERM handle_term = enif_make_binary(env, &handle_bin);
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
out_term = enif_make_tuple2(env, handle_term, size_term);
}

ERL_NIF_TERM handle_list_term = enif_make_list_from_array(env, handle_list, pointer_vec.size());
ERL_NIF_TERM device_size_term = enif_make_uint64(env, device_size);

return exla::nif::ok(env, enif_make_tuple2(env, handle_list_term, device_size_term));
return exla::nif::ok(env, out_term);
}

ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
Expand All @@ -485,40 +513,68 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
}

exla::ExlaClient** client;
std::vector<int64_t> pointer_vec;
ErlNifBinary cuda_ipc_handle_bin;
int cuda_ipc_handle_size = 0;
xla::Shape shape;
int device_id;
std::string pointer_kind;
void* ptr;
int fd = -1;
std::string memname;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_list(env, argv[1], pointer_vec)) {
return exla::nif::error(env, "Unable to get device pointer.");
}
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
if (!exla::nif::get_atom(env, argv[1], pointer_kind)) {
return exla::nif::error(env, "Unable to get device pointer kind.");
}

if (pointer_kind == "cuda_ipc") {
if (!enif_inspect_binary(env, argv[2], &cuda_ipc_handle_bin)) {
return exla::nif::error(env, "Unable to get CUDA IPC handle.");
}
} else if (pointer_kind == "host_ipc") {
const ERL_NIF_TERM* tuple;
int arity;
if (
!enif_get_tuple(env, argv[2], &arity, &tuple) ||
(arity != 2) ||
!exla::nif::get(env, tuple[0], &fd) ||
(fd == -1) ||
!exla::nif::get(env, tuple[1], memname)) {
return exla::nif::error(env, "Unable to get IPC handle.");
}
} else if (pointer_kind == "local") {
int64_t ptr_int;
if (!exla::nif::get(env, argv[2], &ptr_int)) {
return exla::nif::error(env, "Unable to get pointer.");
}

ptr = (void*)ptr_int;
}

if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
if (!exla::nif::get(env, argv[4], &device_id)) {
return exla::nif::error(env, "Unable to get device ordinal.");
}

void* ptr;
if (pointer_kind == "local") {
if (pointer_vec.size() != sizeof(void*)) {
// This helps prevent segfaults if someone passes an IPC handle instead of
// a local pointer.
return exla::nif::error(env, "Invalid pointer size for selected mode.");
}
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
bytePtr[i] = pointer_vec[i];
std::function<void()> on_delete_callback = []() {};

if (pointer_kind == "host_ipc") {
size_t device_size = (size_t)xla::ShapeUtil::ByteSizeOf(shape);

ptr = open_ipc_handle(fd, device_size);
if (ptr == nullptr) {
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
}

on_delete_callback = [fd, memname, ptr, device_size]() {
close_ipc_handle(fd, ptr, (char*)memname.c_str(), device_size);
};
} else if (pointer_kind == "cuda_ipc") {
auto result = get_pointer_for_ipc_handle(pointer_vec, device_id);
auto result = get_pointer_for_ipc_handle(cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id);
if (result.second) {
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
}
Expand All @@ -527,8 +583,8 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E

EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);

std::function<void()> on_delete_callback = []() {};
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);

exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
}
Expand Down
9 changes: 4 additions & 5 deletions exla/c_src/exla/exla_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
return std::make_pair(result, status != cudaSuccess);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
if (handle_list.size() != sizeof(cudaIpcMemHandle_t)) {
printf("Error: Invalid CUDA IPC memory handle size\n");
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
if (handle_size != sizeof(cudaIpcMemHandle_t)) {
return std::make_pair(nullptr, 1); // Return with error status
}

unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
ipc_handle_data[i] = (uint8_t)handle_list[i];
ipc_handle_data[i] = handle_bin[i];
}

cudaIpcMemHandle_t ipc_handle;
Expand All @@ -54,7 +53,7 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
return std::make_pair(std::vector<unsigned char>(0), 1);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
return std::make_pair(nullptr, 1);
}
#endif
3 changes: 2 additions & 1 deletion exla/c_src/exla/exla_cuda.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <vector>

std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>, int);
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t*, size_t, int);
47 changes: 47 additions & 0 deletions exla/c_src/exla/ipc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "ipc.h"

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <iostream>

// Function to create or open a shared memory object and set its size
int get_ipc_handle(const char* memname, size_t memsize) {
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);
if (fd == -1) {
return -1;
}

if (ftruncate(fd, memsize) == -1) {
close(fd);
return -1;
}

return fd;
}

// Function to map the shared memory in this process
void* open_ipc_handle(int fd, size_t memsize) {
void* ptr = mmap(NULL, memsize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if (ptr == MAP_FAILED) {
perror("mmap");
return nullptr;
}
return ptr;
}

int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize) {
if (munmap(ptr, memsize) == -1) {
return -1;
}

if (close(fd) == -1) {
return -1;
}

shm_unlink(memname);

return 0;
}
7 changes: 7 additions & 0 deletions exla/c_src/exla/ipc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

#include <cstddef>

int get_ipc_handle(const char* memname, size_t memsize);
void* open_ipc_handle(int fd, size_t memsize);
int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize);
Loading

0 comments on commit 93e4383

Please sign in to comment.