/*!
 * Copyright (c) 2022 by Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * \file gpu_cache.cu
 * \brief Implementation of the wrapper around the HugeCTR gpu_cache routines.
 */

#ifndef DGL_RUNTIME_CUDA_GPU_CACHE_H_
#define DGL_RUNTIME_CUDA_GPU_CACHE_H_

#include <cuda_runtime.h>
#include <dgl/array.h>
#include <dgl/aten/array_ops.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/object.h>
#include <dgl/runtime/registry.h>

#include <nv_gpu_cache.hpp>

#include "../../runtime/cuda/cuda_common.h"

namespace dgl {
namespace runtime {
namespace cuda {

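/*!
 * \brief Wrapper around HugeCTR's set-associative gpu_cache, storing float32
 *        feature rows keyed by 32-bit or 64-bit IDs.
 *
 * Each bucket holds WARP_SIZE * set_associativity slots, and the requested
 * capacity is rounded up to a whole number of buckets.
 *
 * Typical usage (sketch; `fetched_values` stands for whatever rows the caller
 * gathers for the missing keys from its backing feature storage):
 *
 *   GpuCache<unsigned int> cache(num_items, num_feats);
 *   NDArray values;
 *   IdArray missing_index, missing_keys;
 *   std::tie(values, missing_index, missing_keys) = cache.Query(keys);
 *   // ... fill `fetched_values` with the rows for `missing_keys` ...
 *   cache.Replace(missing_keys, fetched_values);
 */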
template <typename key_t>
class GpuCache : public runtime::Object {
  constexpr static int set_associativity = 2;
  constexpr static int WARP_SIZE = 32;
  constexpr static int bucket_size = WARP_SIZE * set_associativity;
  using gpu_cache_t = gpu_cache::gpu_cache<
      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
      WARP_SIZE>;

 public:
  static constexpr const char *_type_key =
      sizeof(key_t) == 4 ? "cuda.GpuCache32" : "cuda.GpuCache64";
  DGL_DECLARE_OBJECT_TYPE_INFO(GpuCache, Object);

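  /*!
   * \brief Construct a cache able to hold at least num_items rows of
   *        num_feats float32 features each, on the currently active CUDA
   *        device.
   */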
  GpuCache(size_t num_items, size_t num_feats)
      : num_feats(num_feats),
        cache(std::make_unique<gpu_cache_t>(
            (num_items + bucket_size - 1) / bucket_size, num_feats)) {
    CUDA_CALL(cudaGetDevice(&cuda_device));
  }

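  /*!
   * \brief Look up keys in the cache.
   * \param keys 1D tensor of key IDs resident on this cache's CUDA device.
   * \return (values, missing_index, missing_keys): values is a
   *         [num_keys, num_feats] float32 tensor whose rows are filled for
   *         the keys that hit; missing_index holds the positions of the
   *         misses within keys and missing_keys the corresponding IDs, both
   *         sized to the number of misses.
   */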
  std::tuple<NDArray, IdArray, IdArray> Query(IdArray keys) {
    const auto &ctx = keys->ctx;
    cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();
    auto device = dgl::runtime::DeviceAPI::Get(ctx);
    CHECK_EQ(ctx.device_type, kDGLCUDA)
        << "The keys should be on a CUDA device";
    CHECK_EQ(ctx.device_id, cuda_device)
        << "The keys should be on the correct CUDA device";
    CHECK_EQ(keys->ndim, 1)
        << "The tensor of requested indices must be of dimension one.";
    NDArray values = NDArray::Empty(
        {keys->shape[0], (int64_t)num_feats}, DGLDataType{kDGLFloat, 32, 1},
        ctx);
    IdArray missing_index = aten::NewIdArray(keys->shape[0], ctx, 64);
    IdArray missing_keys =
        aten::NewIdArray(keys->shape[0], ctx, sizeof(key_t) * 8);
    size_t *missing_len =
        static_cast<size_t *>(device->AllocWorkspace(ctx, sizeof(size_t)));
    cache->Query(
        static_cast<const key_t *>(keys->data), keys->shape[0],
        static_cast<float *>(values->data),
        static_cast<uint64_t *>(missing_index->data),
        static_cast<key_t *>(missing_keys->data), missing_len, stream);
    size_t missing_len_host;
    device->CopyDataFromTo(
        missing_len, 0, &missing_len_host, 0, sizeof(missing_len_host), ctx,
        DGLContext{kDGLCPU, 0}, keys->dtype);
    device->FreeWorkspace(ctx, missing_len);
    missing_index = missing_index.CreateView(
        {(int64_t)missing_len_host}, missing_index->dtype);
    missing_keys =
        missing_keys.CreateView({(int64_t)missing_len_host}, keys->dtype);
    return std::make_tuple(values, missing_index, missing_keys);
  }

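  /*!
   * \brief Insert or update the given key/value pairs, evicting existing
   *        entries as needed.
   * \param keys 1D tensor of key IDs on this cache's CUDA device.
   * \param values float32 tensor of shape [num_keys, num_feats] on the same
   *        device.
   */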
  void Replace(IdArray keys, NDArray values) {
    cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();
    CHECK_EQ(keys->ctx.device_type, kDGLCUDA)
        << "The keys should be on a CUDA device";
    CHECK_EQ(keys->ctx.device_id, cuda_device)
        << "The keys should be on the correct CUDA device";
    CHECK_EQ(values->ctx.device_type, kDGLCUDA)
        << "The values should be on a CUDA device";
    CHECK_EQ(values->ctx.device_id, cuda_device)
        << "The values should be on the correct CUDA device";
    CHECK_EQ(keys->shape[0], values->shape[0])
        << "First dimensions of keys and values must match";
    CHECK_EQ(values->shape[1], num_feats) << "Embedding dimension must match";
    cache->Replace(
        static_cast<const key_t *>(keys->data), keys->shape[0],
        static_cast<const float *>(values->data), stream);
  }

 private:
  size_t num_feats;
  std::unique_ptr<gpu_cache_t> cache;
  int cuda_device;
};

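// Reference wrappers exposing the cache through DGL's object system, one
// specialization per key width.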
static_assert(sizeof(unsigned int) == 4);
DGL_DEFINE_OBJECT_REF(GpuCacheRef32, GpuCache<unsigned int>);
// The .cu file in the HugeCTR gpu cache uses unsigned int and long long;
// changing to int64_t here results in a mismatch of template arguments.
static_assert(sizeof(long long) == 8);  // NOLINT
DGL_DEFINE_OBJECT_REF(GpuCacheRef64, GpuCache<long long>);  // NOLINT

/* CAPI **********************************************************************/

using namespace dgl::runtime;

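// Packed func: (num_items, num_feats, num_bits) -> GpuCacheRef32 if
// num_bits == 32, otherwise GpuCacheRef64.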
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheCreate")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const size_t num_items = args[0];
      const size_t num_feats = args[1];
      const int num_bits = args[2];

      if (num_bits == 32)
        *rv = GpuCacheRef32(
            std::make_shared<GpuCache<unsigned int>>(num_items, num_feats));
      else
        *rv = GpuCacheRef64(std::make_shared<GpuCache<long long>>(  // NOLINT
            num_items, num_feats));
    });

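// Packed func: (cache, keys) -> [values, missing_index, missing_keys],
// dispatching on the bit width of the key dtype.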
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheQuery")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      IdArray keys = args[1];

      List<ObjectRef> ret;
      if (keys->dtype.bits == 32) {
        GpuCacheRef32 cache = args[0];
        auto result = cache->Query(keys);

        ret.push_back(Value(MakeValue(std::get<0>(result))));
        ret.push_back(Value(MakeValue(std::get<1>(result))));
        ret.push_back(Value(MakeValue(std::get<2>(result))));
      } else {
        GpuCacheRef64 cache = args[0];
        auto result = cache->Query(keys);

        ret.push_back(Value(MakeValue(std::get<0>(result))));
        ret.push_back(Value(MakeValue(std::get<1>(result))));
        ret.push_back(Value(MakeValue(std::get<2>(result))));
      }

      *rv = ret;
    });

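// Packed func: (cache, keys, values) -> []; updates the cache in place.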
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheReplace")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      IdArray keys = args[1];
      NDArray values = args[2];

      if (keys->dtype.bits == 32) {
        GpuCacheRef32 cache = args[0];
        cache->Replace(keys, values);
      } else {
        GpuCacheRef64 cache = args[0];
        cache->Replace(keys, values);
      }

      *rv = List<ObjectRef>{};
    });

}  // namespace cuda
}  // namespace runtime
}  // namespace dgl

#endif  // DGL_RUNTIME_CUDA_GPU_CACHE_H_