@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
22
22
if (reduce == " sum" ) {
23
23
SWITCH_BITS (bits, DType, {
24
24
SWITCH_OP (op, Op, {
25
- DType *out_off = out.Ptr <DType>();
26
- std::fill (out_off, out_off + csr.num_rows * dim, 0 );
27
25
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
28
26
});
29
27
});
@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
33
31
DType *out_off = out.Ptr <DType>();
34
32
IdType* argX = Op::use_lhs ? static_cast <IdType*>(out_aux[0 ]->data ) : nullptr ;
35
33
IdType* argW = Op::use_rhs ? static_cast <IdType*>(out_aux[1 ]->data ) : nullptr ;
36
- if (Op::use_lhs) std::fill (argX, argX + csr.num_rows * dim, 0 );
37
- if (Op::use_rhs) std::fill (argW, argW + csr.num_rows * dim, 0 );
38
34
if (reduce == " max" ) {
39
35
std::fill (out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
40
36
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
66
62
if (reduce == " sum" ) {
67
63
SWITCH_BITS (bits, DType, {
68
64
SWITCH_OP (op, Op, {
69
- // TODO(Israt): Ideally the for loop should go over num_ntypes
70
- for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
71
- DType *out_off = vec_out[out_node_tids[etype]].Ptr <DType>();
72
- std::fill (out_off, out_off + vec_csr[etype].num_rows * dim, 0 );
73
- }
74
65
/* Call SpMM for each relation type */
75
66
for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
76
67
const dgl_type_t src_id = ufeat_node_tids[etype];
@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
86
77
} else if (reduce == " max" || reduce == " min" ) {
87
78
SWITCH_BITS (bits, DType, {
88
79
SWITCH_OP (op, Op, {
89
- // TODO(Israt): Ideally the for loop should go over num_ntypes
90
- for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
91
- IdType* argX = Op::use_lhs ? static_cast <IdType*>(out_aux[0 ]->data ) : nullptr ;
92
- IdType* argW = Op::use_rhs ? static_cast <IdType*>(out_aux[1 ]->data ) : nullptr ;
93
- if (Op::use_lhs) std::fill (argX, argX + vec_csr[etype].num_rows * dim, 0 );
94
- if (Op::use_rhs) std::fill (argW, argW + vec_csr[etype].num_rows * dim, 0 );
95
- }
96
80
/* Call SpMM for each relation type */
97
81
for (dgl_type_t etype = 0 ; etype < ufeat_node_tids.size (); ++etype) {
98
82
const dgl_type_t src_id = ufeat_node_tids[etype];
0 commit comments