Skip to content

Commit d15582b

Browse files
VoVAllenBarclayII
authored andcommitted
Remove redundant fill in SPMM kernel (#3166)
* remove redundant fill * trigger ci
1 parent 5f5a6ef commit d15582b

File tree

1 file changed

+0
-16
lines changed

1 file changed

+0
-16
lines changed

src/array/cpu/spmm.cc

-16
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
2222
if (reduce == "sum") {
2323
SWITCH_BITS(bits, DType, {
2424
SWITCH_OP(op, Op, {
25-
DType *out_off = out.Ptr<DType>();
26-
std::fill(out_off, out_off + csr.num_rows * dim, 0);
2725
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
2826
});
2927
});
@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
3331
DType *out_off = out.Ptr<DType>();
3432
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
3533
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
36-
if (Op::use_lhs) std::fill(argX, argX + csr.num_rows * dim, 0);
37-
if (Op::use_rhs) std::fill(argW, argW + csr.num_rows * dim, 0);
3834
if (reduce == "max") {
3935
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
4036
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
6662
if (reduce == "sum") {
6763
SWITCH_BITS(bits, DType, {
6864
SWITCH_OP(op, Op, {
69-
// TODO(Israt): Ideally the for loop should go over num_ntypes
70-
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
71-
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
72-
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, 0);
73-
}
7465
/* Call SpMM for each relation type */
7566
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
7667
const dgl_type_t src_id = ufeat_node_tids[etype];
@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
8677
} else if (reduce == "max" || reduce == "min") {
8778
SWITCH_BITS(bits, DType, {
8879
SWITCH_OP(op, Op, {
89-
// TODO(Israt): Ideally the for loop should go over num_ntypes
90-
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
91-
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
92-
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
93-
if (Op::use_lhs) std::fill(argX, argX + vec_csr[etype].num_rows * dim, 0);
94-
if (Op::use_rhs) std::fill(argW, argW + vec_csr[etype].num_rows * dim, 0);
95-
}
9680
/* Call SpMM for each relation type */
9781
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
9882
const dgl_type_t src_id = ufeat_node_tids[etype];

0 commit comments

Comments
 (0)