////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
//    this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
////////////////////////////////////////////////////////////////////////////////

#pragma once

#include <cusparse.h>

#include <numeric>

#include "matx/core/cache.h"
#include "matx/core/sparse_tensor.h"
#include "matx/core/tensor.h"

namespace matx {

namespace detail {

/**
 * Parameters needed to execute a cuSPARSE sparse2sparse.
 */
struct Sparse2SparseParams_t {
  MatXDataType_t dtype;
  MatXDataType_t ptype;
  MatXDataType_t ctype;
  cudaStream_t stream;
  index_t nse;
  index_t m;
  index_t n;
  // Matrix handles in cuSPARSE are data specific (unlike e.g. cuBLAS
  // where the same plan can be shared between different data buffers).
  void *ptrO1;
  void *ptrA0;
  void *ptrA1;
  void *ptrA2;
  void *ptrA3;
};

// Helper method to wrap pointer/size in new storage.
template <typename T>
__MATX_INLINE__ static auto wrapDefaultNonOwningStorage(T *ptr, size_t sz) {
  raw_pointer_buffer<T, matx_allocator<T>> buf{ptr, sz * sizeof(T),
                                               /*owning=*/false};
  return basic_storage<decltype(buf)>{std::move(buf)};
}

template <typename TensorTypeO, typename TensorTypeA>
class Sparse2SparseHandle_t {
public:
  using TA = typename TensorTypeA::value_type;
  using TO = typename TensorTypeO::value_type;

  using VAL = typename TensorTypeO::val_type;
  using POS = typename TensorTypeO::pos_type;
  using CRD = typename TensorTypeO::crd_type;

  /**
   * Construct a sparse2sparse handle.
   */
  Sparse2SparseHandle_t(TensorTypeO &o, const TensorTypeA &a,
                        cudaStream_t stream) {
    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
    params_ = GetConvParams(o, a, stream);

    [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
    // Bind the handle to the requested stream so that the conversion launched
    // in Exec() runs in stream order with the rest of the work.
    ret = cusparseSetStream(handle_, stream);
    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);

    static_assert(is_sparse_tensor_v<TensorTypeA>);
    static_assert(is_sparse_tensor_v<TensorTypeO>);

    if constexpr (TensorTypeA::Format::isCOO() &&
                  TensorTypeO::Format::isCSR()) {
      // For speed of operation, the CSR output shamelessly
      // steals the values and j-index buffers from COO.
      VAL *val = reinterpret_cast<VAL *>(params_.ptrA0);
      CRD *crd = reinterpret_cast<CRD *>(params_.ptrA3);
      o.SetVal(wrapDefaultNonOwningStorage<VAL>(val, params_.nse));
      o.SetCrd(1, wrapDefaultNonOwningStorage<CRD>(crd, params_.nse));
      o.SetSparseDataImpl();
    } else {
      MATX_THROW(matxNotSupported,
                 "Sparse2Sparse currently only supports COO2CSR");
    }
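    // At this point the CSR output shares the value and column-index buffers
    // with the COO input; only the CSR row-pointer (position) array remains
    // to be computed, which Exec() does from the COO row indices.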
  }

  ~Sparse2SparseHandle_t() { cusparseDestroy(handle_); }

  static detail::Sparse2SparseParams_t
  GetConvParams(TensorTypeO &o, const TensorTypeA &a, cudaStream_t stream) {
    detail::Sparse2SparseParams_t params;
    params.dtype = TypeToInt<VAL>();
    params.ptype = TypeToInt<POS>();
    params.ctype = TypeToInt<CRD>();
    params.stream = stream;
    // TODO: simple no-batch, row-wise, no-transpose for now
    params.nse = a.Nse();
    params.m = a.Size(TensorTypeA::Rank() - 2);
    params.n = a.Size(TensorTypeA::Rank() - 1);
    // Matrix handles in cuSPARSE are data specific. Therefore, the pointers to
    // the underlying buffers are part of the conversion parameters. In this
    // case, only the position pointers uniquely determine the sparse output,
    // since the value and coordinate data will be re-allocated on execution.
    params.ptrO1 = o.POSData(1);
    params.ptrA0 = a.Data();
    params.ptrA1 = a.POSData(0);
    params.ptrA2 = a.CRDData(0);
    params.ptrA3 = a.CRDData(1);
    return params;
  }

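  // Exec() below compresses the sorted COO row indices into CSR row pointers.
  // Illustrative example (values not from the source): row indices
  // [0, 0, 1, 3] with m = 4 yield the position array [0, 2, 3, 3, 4]; the
  // m + 1 positions bracket each row, so adjacent differences give the number
  // of entries per row.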
  __MATX_INLINE__ void Exec([[maybe_unused]] TensorTypeO &o,
                            [[maybe_unused]] const TensorTypeA &a) {
    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
    const cusparseIndexBase_t base = CUSPARSE_INDEX_BASE_ZERO;
    // Legacy API takes specific types only.
    CRD *crd = reinterpret_cast<CRD *>(params_.ptrA2);
    POS *pos = reinterpret_cast<POS *>(params_.ptrO1);
    const int nse = static_cast<int>(params_.nse);
    const int m = static_cast<int>(params_.m);
    [[maybe_unused]] cusparseStatus_t ret =
        cusparseXcoo2csr(handle_, crd, nse, m, pos, base);
    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
  }

private:
  cusparseHandle_t handle_ = nullptr; // TODO: share handle globally?
  detail::Sparse2SparseParams_t params_;
};

/**
 * Crude hash on Sparse2Sparse to get a reasonably good delta for collisions.
 * This doesn't need to be perfect, but fast enough to not slow down lookups,
 * and different enough so the common conversion parameters change.
 */
struct Sparse2SparseParamsKeyHash {
  std::size_t operator()(const Sparse2SparseParams_t &k) const noexcept {
    return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrO1)) +
           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrA0)) +
           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.stream));
  }
};

/**
 * Test Sparse2Sparse parameters for equality. Unlike the hash, all parameters
 * must match exactly to ensure the hashed kernel can be reused for the
 * computation.
 */
struct Sparse2SparseParamsKeyEq {
  bool operator()(const Sparse2SparseParams_t &l,
                  const Sparse2SparseParams_t &t) const noexcept {
    return l.dtype == t.dtype && l.ptype == t.ptype && l.ctype == t.ctype &&
           l.stream == t.stream && l.nse == t.nse && l.m == t.m && l.n == t.n &&
           l.ptrO1 == t.ptrO1 && l.ptrA0 == t.ptrA0 && l.ptrA1 == t.ptrA1 &&
           l.ptrA2 == t.ptrA2 && l.ptrA3 == t.ptrA3;
  }
};

using sparse2sparse_cache_t =
    std::unordered_map<Sparse2SparseParams_t, std::any,
                       Sparse2SparseParamsKeyHash, Sparse2SparseParamsKeyEq>;

} // end namespace detail

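/**
 * Convert between sparse storage formats (currently only COO to CSR) using
 * cuSPARSE on the given executor's stream.
 *
 * Illustrative usage sketch, not part of this header. The factory helpers
 * shown are assumed from the experimental MatX sparse API (names and
 * signatures may differ); int32_t index types are required by this routine:
 *
 *   cudaExecutor exec{};
 *   auto Acoo = experimental::make_zeros_tensor_coo<float, int32_t>({m, n});
 *   auto Acsr =
 *       experimental::make_zeros_tensor_csr<float, int32_t, int32_t>({m, n});
 *   // ... fill Acoo with nonzeros, sorted by row ...
 *   sparse2sparse_impl(Acsr, Acoo, exec);
 */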
template <typename OutputTensorType, typename InputTensorType>
void sparse2sparse_impl(OutputTensorType &o, const InputTensorType &a,
                        const cudaExecutor &exec) {
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
  const auto stream = exec.getStream();

  using atype = InputTensorType;
  using otype = OutputTensorType;

  using TA = typename atype::value_type;
  using TO = typename otype::value_type;

  static constexpr int RANKA = atype::Rank();
  static constexpr int RANKO = otype::Rank();

  // Restrictions.
  static_assert(RANKA == 2 && RANKO == 2, "tensors must have rank 2");
  static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
  static_assert(std::is_same_v<typename atype::crd_type, int32_t> &&
                std::is_same_v<typename otype::pos_type, int32_t> &&
                std::is_same_v<typename otype::crd_type, int32_t>,
                "unsupported index type");

  // Get parameters required by these tensors (for caching).
  auto params =
      detail::Sparse2SparseHandle_t<otype, atype>::GetConvParams(o, a, stream);

  // Lookup and cache.
  using cache_val_type = detail::Sparse2SparseHandle_t<otype, atype>;
  detail::GetCache().LookupAndExec<detail::sparse2sparse_cache_t>(
      detail::GetCacheIdFromType<detail::sparse2sparse_cache_t>(), params,
      [&]() { return std::make_shared<cache_val_type>(o, a, stream); },
      [&](std::shared_ptr<cache_val_type> cache_type) {
        cache_type->Exec(o, a);
      });
}

} // end namespace matx