Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a heap for small sizes #1911

Merged
merged 2 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions mlx/backend/metal/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

namespace mlx::core {

// Options passed both to the allocator's heap descriptor and to every
// newBuffer call: untracked hazard tracking combined with shared storage mode.
constexpr size_t resource_options =
    MTL::ResourceHazardTrackingModeUntracked | MTL::ResourceStorageModeShared;

namespace allocator {

Allocator& allocator() {
Expand Down Expand Up @@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
std::get<size_t>(device_info().at("max_recommended_working_set_size"));
resource_limit_ = std::get<size_t>(device_info().at("resource_limit"));
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
bool is_vm = std::get<std::string>(device_info().at("device_name")) ==
"Apple Paravirtual device";
if (is_vm) {
return;
}
auto heap_desc = MTL::HeapDescriptor::alloc()->init();
heap_desc->setResourceOptions(resource_options);
heap_desc->setSize(heap_size_);
heap_ = device_->newHeap(heap_desc);
heap_desc->release();
residency_set_.insert(heap_);
}

// Releases the allocator's reference to the small-allocation heap, if one is
// held.
MetalAllocator::~MetalAllocator() {
  // Keep a scoped memory pool alive for the duration of the release call.
  auto pool = metal::new_scoped_memory_pool();
  // heap_ is only assigned when the constructor actually creates the heap;
  // nothing to release otherwise.
  if (heap_ == nullptr) {
    return;
  }
  heap_->release();
}

size_t MetalAllocator::set_cache_limit(size_t limit) {
Expand Down Expand Up @@ -226,16 +248,19 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
}

// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
<< ") exceeded.";
throw std::runtime_error(msg.str());
}
lk.unlock();
buf = device_->newBuffer(size, res_opt);
if (size < small_size_ && heap_) {
buf = heap_->newBuffer(size, resource_options);
}
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
lk.lock();
if (buf) {
num_resources_++;
Expand All @@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
peak_memory_ = std::max(peak_memory_, active_memory_);

// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
if (get_cache_memory() > max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
}

residency_set_.insert(buf);
if (!buf->heap()) {
residency_set_.insert(buf);
}

return Buffer{static_cast<void*>(buf)};
}
Expand All @@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
Expand Down Expand Up @@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
Expand Down
9 changes: 9 additions & 0 deletions mlx/backend/metal/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BufferCache {
void remove_from_list(BufferHolder* to_remove);

MTL::Device* device_;
MTL::Heap* heap_{nullptr};

std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
Expand Down Expand Up @@ -78,7 +79,15 @@ class MetalAllocator : public allocator::Allocator {

private:
MTL::Device* device_;

// Allocations smaller than small_size_ go on the heap until it is full. This
// size is chosen because it is the actual minimum size of a buffer allocated
// from the heap, so a heap can hold at most heap_size_ / small_size_ buffers.
static constexpr int small_size_ = 256;
// Size requested for heap_ when the constructor creates it.
static constexpr int heap_size_ = 1 << 20;
// Heap serving the small allocations. The constructor returns before creating
// it on the "Apple Paravirtual device", while malloc and the destructor both
// test this pointer, so it must start out null rather than indeterminate.
MTL::Heap* heap_{nullptr};
MetalAllocator();
~MetalAllocator();
friend MetalAllocator& allocator();

// Caching allocator
Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,12 +692,13 @@ void new_stream(Stream stream) {
}
}

std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());

size_t memsize = 0;
Expand All @@ -711,6 +712,7 @@ device_info() {
}

return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void start_capture(std::string path = "");
void stop_capture();

/** Get information about the GPU and system settings. */
std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();

} // namespace mlx::core::metal
2 changes: 1 addition & 1 deletion mlx/backend/no_metal/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}

std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
Expand Down