Commit 9b2e882

Implemented sparse2sparse transformation (COO to CSR using cuSPARSE) (#918)
1 parent 9207c50 commit 9b2e882

File tree

5 files changed: +350 -2 lines changed


docs_input/basics/sparse_tensor.rst

+3 -2

@@ -87,11 +87,12 @@ correct way of performing the conversion above is as follows::

     (A = sparse2dense(Acoo)).run(exec);

 The current experimental sparse support in MatX provides efficient
-operations for sparse-to-dense, dense-to-sparse, matvec, matmul,
-and solve::
+operations for sparse-to-dense, dense-to-sparse, sparse-to-sparse,
+matvec, matmul, and solve::

     (A = sparse2dense(Acoo)).run(exec);
     (Acoo = dense2sparse(D)).run(exec);
+    (Acsr = sparse2sparse(Acoo)).run(exec);
     (V = matvec(Acoo, W)).run(exec); // only Sparse-Matrix x Vector (SpMV)
     (C = matmul(Acoo, B)).run(exec); // only Sparse-Matrix x Matrix (SpMM)
     (X = solve(Acsr, Y)).run(exec);  // only on CSR format
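
As a point of reference, the documentation snippet above assumes the tensors have already been declared. The following is a minimal sketch of the full flow, assuming using namespace matx; the shapes, the default-constructed executor, and the exact experimental::make_zero_tensor_coo factory signature are illustrative assumptions, not taken from this commit:

  cudaExecutor exec{};
  auto D    = make_tensor<float>({4, 8});                                   // dense source
  auto Acoo = experimental::make_zero_tensor_coo<float, int>({4, 8});       // assumed COO factory call
  auto Acsr = experimental::make_zero_tensor_csr<float, int, int>({4, 8});  // as in the example file below

  (Acoo = dense2sparse(D)).run(exec);      // dense -> COO
  (Acsr = sparse2sparse(Acoo)).run(exec);  // COO -> CSR (added by this commit)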

examples/sparse_tensor.cu

+9

@@ -166,5 +166,14 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
   (Acoo = dense2sparse(D)).run(exec);
   print(Acoo);

+  //
+  // Conversions between sparse formats: COO to CSR.
+  // For speed-of-operation, the CSR output actually
+  // shares some of the buffers with COO on completion.
+  //
+  auto Acsr2 = experimental::make_zero_tensor_csr<float, int, int>({4, 8});
+  (Acsr2 = sparse2sparse(Acoo)).run(exec);
+  print(Acsr2);
+
   MATX_EXIT_HANDLER();
 }
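
A quick way to observe the buffer sharing mentioned in the comment above is to compare the value-buffer pointers after the conversion. This is a hypothetical check, not part of the example, and it assumes Data() on a sparse tensor returns its values pointer (as the transform below relies on):

  if (Acsr2.Data() == Acoo.Data()) {
    printf("CSR values buffer aliases the COO values buffer\n");
  }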

include/matx/operators/operators.h

+1

@@ -101,6 +101,7 @@
 #include "matx/operators/sign.h"
 #include "matx/operators/slice.h"
 #include "matx/operators/sparse2dense.h"
+#include "matx/operators/sparse2sparse.h"
 #include "matx/operators/solve.h"
 #include "matx/operators/sort.h"
 #include "matx/operators/sph2cart.h"
include/matx/operators/sparse2sparse.h

+107 (new file)

@@ -0,0 +1,107 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/convert/sparse2sparse_cusparse.h"

namespace matx {
namespace detail {

template <typename OpA>
class Sparse2SparseOp : public BaseOp<Sparse2SparseOp<OpA>> {
private:
  typename detail::base_type_t<OpA> a_;

public:
  using matxop = bool;
  using matx_transform_op = bool;
  using tosparse_xform_op = bool;
  using value_type = typename OpA::value_type;

  __MATX_INLINE__ Sparse2SparseOp(const OpA &a) : a_(a) {}

  __MATX_INLINE__ std::string str() const {
    return "sparse2sparse(" + get_type_str(a_) + ")";
  }

  static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
  Rank() {
    return remove_cvref_t<OpA>::Rank();
  }

  constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
  Size(int dim) const {
    return a_.Size(dim);
  }

  template <typename Out, typename Executor>
  void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
    if constexpr (is_sparse_tensor_v<OpA> && is_sparse_tensor_v<Out>) {
      // NOTE: sparse assignment O = sparse2sparse(A) takes direct reference!
      sparse2sparse_impl(out, a_, ex);
    } else {
      MATX_THROW(matxNotSupported, "Cannot use sparse2sparse on dense operands");
    }
  }
};

} // end namespace detail

/**
 * Convert a sparse tensor into a sparse tensor. Typically
 * used to convert storage format (e.g. COO to CSR). Note that
 * for speed-of-operation, after this operation, the input and
 * output tensor may share some of the underlying allocated
 * memory (e.g. the values array).
 *
 * Currently only COO to CSR is supported (with CSR "stealing"
 * the j-index and values array from COO, while recomputing its
 * own positions array).
 *
 * @tparam OpA
 *   Data type of A tensor
 *
 * @param A
 *   Sparse input tensor
 *
 * @return
 *   Sparse output tensor
 */
template <typename OpA> __MATX_INLINE__ auto sparse2sparse(const OpA &A) {
  return detail::Sparse2SparseOp(A);
}

} // end namespace matx
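
To make the doc comment above concrete, here is a small worked illustration of the conversion for a 4 x 8 matrix with five stored entries (the data is illustrative only, not taken from the commit):

  // COO input (coordinates sorted row-wise):
  //   values  = { 1, 2, 3, 4, 5 }
  //   i-index = { 0, 0, 3, 3, 3 }    // row coordinates
  //   j-index = { 0, 1, 2, 3, 5 }    // column coordinates
  //
  // CSR output of sparse2sparse:
  //   positions = { 0, 2, 2, 2, 5 }  // row pointers, recomputed from the i-index
  //   j-index   = { 0, 1, 2, 3, 5 }  // same buffer as the COO j-index ("stolen")
  //   values    = { 1, 2, 3, 4, 5 }  // same buffer as the COO values ("stolen")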
include/matx/transforms/convert/sparse2sparse_cusparse.h

+230 (new file)

@@ -0,0 +1,230 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include <cusparse.h>

#include <numeric>

#include "matx/core/cache.h"
#include "matx/core/sparse_tensor.h"
#include "matx/core/tensor.h"

namespace matx {

namespace detail {

/**
 * Parameters needed to execute a cuSPARSE sparse2sparse.
 */
struct Sparse2SparseParams_t {
  MatXDataType_t dtype;
  MatXDataType_t ptype;
  MatXDataType_t ctype;
  cudaStream_t stream;
  index_t nse;
  index_t m;
  index_t n;
  // Matrix handles in cuSPARSE are data specific (unlike e.g. cuBLAS
  // where the same plan can be shared between different data buffers).
  void *ptrO1;
  void *ptrA0;
  void *ptrA1;
  void *ptrA2;
  void *ptrA3;
};

// Helper method to wrap pointer/size in new storage.
template <typename T>
__MATX_INLINE__ static auto wrapDefaultNonOwningStorage(T *ptr, size_t sz) {
  raw_pointer_buffer<T, matx_allocator<T>> buf{ptr, sz * sizeof(T),
                                               /*owning=*/false};
  return basic_storage<decltype(buf)>{std::move(buf)};
}

template <typename TensorTypeO, typename TensorTypeA>
class Sparse2SparseHandle_t {
public:
  using TA = typename TensorTypeA::value_type;
  using TO = typename TensorTypeO::value_type;

  using VAL = typename TensorTypeO::val_type;
  using POS = typename TensorTypeO::pos_type;
  using CRD = typename TensorTypeO::crd_type;

  /**
   * Construct a sparse2sparse handle.
   */
  Sparse2SparseHandle_t(TensorTypeO &o, const TensorTypeA &a,
                        cudaStream_t stream) {
    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
    params_ = GetConvParams(o, a, stream);

    [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);

    static_assert(is_sparse_tensor_v<TensorTypeA>);
    static_assert(is_sparse_tensor_v<TensorTypeO>);

    if constexpr (TensorTypeA::Format::isCOO() &&
                  TensorTypeO::Format::isCSR()) {
      // For speed-of-operation, CSR output shamelessly
      // steals the values and j-index buffers from COO.
      VAL *val = reinterpret_cast<VAL *>(params_.ptrA0);
      CRD *crd = reinterpret_cast<CRD *>(params_.ptrA3);
      o.SetVal(wrapDefaultNonOwningStorage<VAL>(val, params_.nse));
      o.SetCrd(1, wrapDefaultNonOwningStorage<CRD>(crd, params_.nse));
      o.SetSparseDataImpl();
    } else {
      MATX_THROW(matxNotSupported,
                 "Sparse2Sparse currently only supports COO2CSR");
    }
  }

  ~Sparse2SparseHandle_t() { cusparseDestroy(handle_); }

  static detail::Sparse2SparseParams_t
  GetConvParams(TensorTypeO &o, const TensorTypeA &a, cudaStream_t stream) {
    detail::Sparse2SparseParams_t params;
    params.dtype = TypeToInt<VAL>();
    params.ptype = TypeToInt<POS>();
    params.ctype = TypeToInt<CRD>();
    params.stream = stream;
    // TODO: simple no-batch, row-wise, no-transpose for now
    params.nse = a.Nse();
    params.m = a.Size(TensorTypeA::Rank() - 2);
    params.n = a.Size(TensorTypeA::Rank() - 1);
    // Matrix handles in cuSPARSE are data specific. Therefore, the pointers to
    // the underlying buffers are part of the conversion parameters. In this
    // case, only the position pointers uniquely determine the sparse output,
    // since the value and coordinate data will be re-allocated on execution.
    params.ptrO1 = o.POSData(1);
    params.ptrA0 = a.Data();
    params.ptrA1 = a.POSData(0);
    params.ptrA2 = a.CRDData(0);
    params.ptrA3 = a.CRDData(1);
    return params;
  }

  __MATX_INLINE__ void Exec([[maybe_unused]] TensorTypeO &o,
                            [[maybe_unused]] const TensorTypeA &a) {
    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
    const cusparseIndexBase_t base = CUSPARSE_INDEX_BASE_ZERO;
    // Legacy API takes specific types only.
    CRD *crd = reinterpret_cast<CRD *>(params_.ptrA2);
    POS *pos = reinterpret_cast<POS *>(params_.ptrO1);
    const int nse = static_cast<int>(params_.nse);
    const int m = static_cast<int>(params_.m);
    [[maybe_unused]] cusparseStatus_t ret =
        cusparseXcoo2csr(handle_, crd, nse, m, pos, base);
    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
  }

private:
  cusparseHandle_t handle_ = nullptr; // TODO: share handle globally?
  detail::Sparse2SparseParams_t params_;
};

/**
 * Crude hash on Sparse2Sparse to get a reasonably good delta for collisions.
 * This doesn't need to be perfect, but fast enough to not slow down lookups,
 * and different enough so the common conversion parameters change.
 */
struct Sparse2SparseParamsKeyHash {
  std::size_t operator()(const Sparse2SparseParams_t &k) const noexcept {
    return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrO1)) +
           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrA0)) +
           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.stream));
  }
};

/**
 * Test Sparse2Sparse parameters for equality. Unlike the hash, all parameters
 * must match exactly to ensure the hashed kernel can be reused for the
 * computation.
 */
struct Sparse2SparseParamsKeyEq {
  bool operator()(const Sparse2SparseParams_t &l,
                  const Sparse2SparseParams_t &t) const noexcept {
    return l.dtype == t.dtype && l.ptype == t.ptype && l.ctype == t.ctype &&
           l.stream == t.stream && l.nse == t.nse && l.m == t.m && l.n == t.n &&
           l.ptrO1 == t.ptrO1 && l.ptrA0 == t.ptrA0 && l.ptrA1 == t.ptrA1 &&
           l.ptrA2 == t.ptrA2 && l.ptrA3 == t.ptrA3;
  }
};

using sparse2sparse_cache_t =
    std::unordered_map<Sparse2SparseParams_t, std::any,
                       Sparse2SparseParamsKeyHash, Sparse2SparseParamsKeyEq>;

} // end namespace detail

template <typename OutputTensorType, typename InputTensorType>
void sparse2sparse_impl(OutputTensorType &o, const InputTensorType &a,
                        const cudaExecutor &exec) {
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
  const auto stream = exec.getStream();

  using atype = InputTensorType;
  using otype = OutputTensorType;

  using TA = typename atype::value_type;
  using TO = typename otype::value_type;

  static constexpr int RANKA = atype::Rank();
  static constexpr int RANKO = otype::Rank();

  // Restrictions.
  static_assert(RANKA == 2 && RANKO == 2, "tensors must have rank-2");
  static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
  static_assert(std::is_same_v<typename atype::crd_type, int32_t> &&
                    std::is_same_v<typename otype::pos_type, int32_t> &&
                    std::is_same_v<typename otype::crd_type, int32_t>,
                "unsupported index type");

  // Get parameters required by these tensors (for caching).
  auto params =
      detail::Sparse2SparseHandle_t<otype, atype>::GetConvParams(o, a, stream);

  // Lookup and cache.
  using cache_val_type = detail::Sparse2SparseHandle_t<otype, atype>;
  detail::GetCache().LookupAndExec<detail::sparse2sparse_cache_t>(
      detail::GetCacheIdFromType<detail::sparse2sparse_cache_t>(), params,
      [&]() { return std::make_shared<cache_val_type>(o, a, stream); },
      [&](std::shared_ptr<cache_val_type> cache_type) {
        cache_type->Exec(o, a);
      });
}

} // end namespace matx
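
For readers unfamiliar with the legacy cuSPARSE routine used in Exec above, here is a small self-contained sketch (not part of the commit; the array contents are illustrative) showing how cusparseXcoo2csr compresses sorted COO row indices into a CSR row-pointer array:

  #include <cstdio>
  #include <cuda_runtime.h>
  #include <cusparse.h>

  int main() {
    // COO row indices for a 4-row matrix with 5 nonzeros, sorted by row.
    const int m = 4;
    const int nnz = 5;
    const int cooRows[5] = {0, 0, 3, 3, 3};

    int *dRows = nullptr, *dPtr = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&dRows), nnz * sizeof(int));
    cudaMalloc(reinterpret_cast<void **>(&dPtr), (m + 1) * sizeof(int));
    cudaMemcpy(dRows, cooRows, nnz * sizeof(int), cudaMemcpyHostToDevice);

    cusparseHandle_t handle;
    cusparseCreate(&handle);

    // Compress the sorted COO row indices into the CSR positions (row pointer) array.
    cusparseXcoo2csr(handle, dRows, nnz, m, dPtr, CUSPARSE_INDEX_BASE_ZERO);

    int csrPtr[5]; // m + 1 entries
    cudaMemcpy(csrPtr, dPtr, (m + 1) * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i <= m; i++) {
      printf("%d ", csrPtr[i]); // expected: 0 2 2 2 5
    }
    printf("\n");

    cusparseDestroy(handle);
    cudaFree(dRows);
    cudaFree(dPtr);
    return 0;
  }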
