
Commit 05aebd8

tingyu66, schmidt-ju, BarclayII, mufeili, and Ubuntu authored and committed
[Model] Update cugraph-ops models for 23.04 release (#5540)
Co-authored-by: schmidt-ju <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: peizhou001 <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
1 parent 9892abd commit 05aebd8

7 files changed: +118 −87
python/dgl/nn/pytorch/conv/cugraph_base.py (new file)

+57
@@ -0,0 +1,57 @@
+"""An abstract base class for cugraph-ops nn module."""
+import torch
+from torch import nn
+
+
+class CuGraphBaseConv(nn.Module):
+    r"""An abstract base class for cugraph-ops nn module."""
+
+    def __init__(self):
+        super().__init__()
+        self._cached_offsets_fg = None
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        raise NotImplementedError
+
+    def forward(self, *args):
+        r"""Runs the forward pass of the module."""
+        raise NotImplementedError
+
+    def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
+        r"""Pad zero-in-degree nodes to the end of offsets to reach size.
+
+        cugraph-ops often provides two variants of aggregation functions for a
+        specific model: one intended for sampled-graph use cases, one for
+        full-graph ones. The former is in general more performant, however, it
+        only works when the sample size (the max of in-degrees) is small (<200),
+        due to the limit of GPU shared memory. For graphs with a larger max
+        in-degree, we need to fall back to the full-graph option, which requires
+        to convert a DGL block to a full graph. With the csc-representation,
+        this is equivalent to pad zero-in-degree nodes to the end of the offsets
+        array (also called indptr or colptr).
+
+        Parameters
+        ----------
+        offsets :
+            The (monotonically increasing) index pointer array in a CSC-format
+            graph.
+        size : int
+            The length of offsets after padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The augmented offsets array.
+        """
+        if self._cached_offsets_fg is None:
+            self._cached_offsets_fg = torch.empty(
+                size, dtype=offsets.dtype, device=offsets.device
+            )
+        elif self._cached_offsets_fg.numel() < size:
+            self._cached_offsets_fg.resize_(size)
+
+        self._cached_offsets_fg[: offsets.numel()] = offsets
+        self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]
+
+        return self._cached_offsets_fg[:size]
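Note: a standalone sketch of the padding semantics with toy values (illustrative only; the method itself additionally caches the padded buffer in self._cached_offsets_fg so repeated forward passes reuse one allocation):

# Illustrative sketch of pad_offsets semantics, not the cached implementation.
import torch

offsets = torch.tensor([0, 2, 5])   # CSC indptr: 2 dst nodes, 5 edges
size = 5                            # target length, i.e. num_src_nodes + 1
padded = torch.empty(size, dtype=offsets.dtype, device=offsets.device)
padded[: offsets.numel()] = offsets
padded[offsets.numel() :] = offsets[-1]
print(padded)  # tensor([0, 2, 5, 5, 5]): the appended entries repeat the
               # last offset, i.e. the padding nodes have in-degree zero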

python/dgl/nn/pytorch/conv/cugraph_gatconv.py

+19 −24
@@ -5,24 +5,27 @@
 import torch
 from torch import nn
 
+from .cugraph_base import CuGraphBaseConv
+
 try:
-    from pylibcugraphops import make_fg_csr, make_mfg_csr
-    from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
+    from pylibcugraphops.pytorch import SampledCSC, StaticCSC
+    from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg
+
+    HAS_PYLIBCUGRAPHOPS = True
 except ImportError:
-    has_pylibcugraphops = False
-else:
-    has_pylibcugraphops = True
+    HAS_PYLIBCUGRAPHOPS = False
 
 
-class CuGraphGATConv(nn.Module):
+class CuGraphGATConv(CuGraphBaseConv):
     r"""Graph attention layer from `Graph Attention Networks
     <https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation
     accelerated by cugraph-ops.
 
     See :class:`dgl.nn.pytorch.conv.GATConv` for mathematical model.
 
     This module depends on :code:`pylibcugraphops` package, which can be
-    installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
+    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
+    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
 
     .. note::
         This is an **experimental** feature.
@@ -78,7 +81,7 @@ class CuGraphGATConv(nn.Module):
            [ 1.6477, -1.9986],
            [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
     """
-    MAX_IN_DEGREE_MFG = 500
+    MAX_IN_DEGREE_MFG = 200
 
     def __init__(
         self,
@@ -91,10 +94,11 @@ def __init__(
         activation=None,
         bias=True,
     ):
-        if has_pylibcugraphops is False:
+        if HAS_PYLIBCUGRAPHOPS is False:
             raise ModuleNotFoundError(
-                f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
-                f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
+                f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
+                f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. "
+                f"pylibcugraphops requires Python 3.8 or 3.10."
             )
         super().__init__()
         self.in_feats = in_feats
@@ -170,25 +174,17 @@ def forward(self, g, feat, max_in_degree=None):
                 max_in_degree = g.in_degrees().max().item()
 
             if max_in_degree < self.MAX_IN_DEGREE_MFG:
-                _graph = make_mfg_csr(
-                    g.dstnodes(),
+                _graph = SampledCSC(
                     offsets,
                     indices,
                     max_in_degree,
                     g.num_src_nodes(),
                 )
             else:
-                offsets_fg = torch.empty(
-                    g.num_src_nodes() + 1,
-                    dtype=offsets.dtype,
-                    device=offsets.device,
-                )
-                offsets_fg[: offsets.numel()] = offsets
-                offsets_fg[offsets.numel() :] = offsets[-1]
-
-                _graph = make_fg_csr(offsets_fg, indices)
+                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
+                _graph = StaticCSC(offsets_fg, indices)
         else:
-            _graph = make_fg_csr(offsets, indices)
+            _graph = StaticCSC(offsets, indices)
 
         feat = self.feat_drop(feat)
         feat_transformed = self.fc(feat)
@@ -199,7 +195,6 @@ def forward(self, g, feat, max_in_degree=None):
             self.num_heads,
             "LeakyReLU",
             self.negative_slope,
-            add_own_node=False,
             concat_heads=True,
         )[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats)
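For context, a minimal usage sketch of this layer (assumes a CUDA device and pylibcugraphops=23.04; the constructor and forward signatures follow the class docstring):

import dgl
import torch
from dgl.nn import CuGraphGATConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])).to("cuda")
feat = torch.randn(4, 10, device="cuda")
conv = CuGraphGATConv(10, 2, num_heads=3).to("cuda")
out = conv(g, feat)  # shape: (4, 3, 2)
# On a sampled block, pass max_in_degree (e.g. the sampler fanout) to skip
# the in-degree scan; values below MAX_IN_DEGREE_MFG (now 200) take the
# faster SampledCSC path, larger ones fall back to StaticCSC via pad_offsets.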

python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py

+24 −37
@@ -6,18 +6,20 @@
 import torch
 from torch import nn
 
+from .cugraph_base import CuGraphBaseConv
+
 try:
-    from pylibcugraphops import make_fg_csr_hg, make_mfg_csr_hg
-    from pylibcugraphops.torch.autograd import (
+    from pylibcugraphops.pytorch import SampledHeteroCSC, StaticHeteroCSC
+    from pylibcugraphops.pytorch.operators import (
         agg_hg_basis_n2n_post as RelGraphConvAgg,
     )
+
+    HAS_PYLIBCUGRAPHOPS = True
 except ImportError:
-    has_pylibcugraphops = False
-else:
-    has_pylibcugraphops = True
+    HAS_PYLIBCUGRAPHOPS = False
 
 
-class CuGraphRelGraphConv(nn.Module):
+class CuGraphRelGraphConv(CuGraphBaseConv):
     r"""An accelerated relational graph convolution layer from `Modeling
     Relational Data with Graph Convolutional Networks
     <https://arxiv.org/abs/1703.06103>`__ that leverages the highly-optimized
@@ -26,7 +28,8 @@ class CuGraphRelGraphConv(nn.Module):
     See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model.
 
     This module depends on :code:`pylibcugraphops` package, which can be
-    installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
+    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
+    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
 
     .. note::
         This is an **experimental** feature.
@@ -92,10 +95,11 @@ def __init__(
         dropout=0.0,
         apply_norm=False,
     ):
-        if has_pylibcugraphops is False:
+        if HAS_PYLIBCUGRAPHOPS is False:
             raise ModuleNotFoundError(
-                f"{self.__class__.__name__} requires pylibcugraphops >= 23.02 "
-                f"to be installed."
+                f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
+                f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. "
+                f"pylibcugraphops requires Python 3.8 or 3.10."
             )
         super().__init__()
         self.in_feat = in_feat
@@ -176,53 +180,36 @@ def forward(self, g, feat, etypes, max_in_degree=None):
         torch.Tensor
             New node features. Shape: :math:`(|V|, D_{out})`.
         """
-        # Create csc-representation and cast etypes to int32.
         offsets, indices, edge_ids = g.adj_tensors("csc")
         edge_types_perm = etypes[edge_ids.long()].int()
 
-        # Create cugraph-ops graph.
         if g.is_block:
             if max_in_degree is None:
                 max_in_degree = g.in_degrees().max().item()
 
             if max_in_degree < self.MAX_IN_DEGREE_MFG:
-                _graph = make_mfg_csr_hg(
-                    g.dstnodes(),
+                _graph = SampledHeteroCSC(
                     offsets,
                     indices,
+                    edge_types_perm,
                     max_in_degree,
                     g.num_src_nodes(),
-                    n_node_types=0,
-                    n_edge_types=self.num_rels,
-                    out_node_types=None,
-                    in_node_types=None,
-                    edge_types=edge_types_perm,
+                    self.num_rels,
                 )
             else:
-                offsets_fg = torch.empty(
-                    g.num_src_nodes() + 1,
-                    dtype=offsets.dtype,
-                    device=offsets.device,
-                )
-                offsets_fg[: offsets.numel()] = offsets
-                offsets_fg[offsets.numel() :] = offsets[-1]
-
-                _graph = make_fg_csr_hg(
+                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
+                _graph = StaticHeteroCSC(
                     offsets_fg,
                     indices,
-                    n_node_types=0,
-                    n_edge_types=self.num_rels,
-                    node_types=None,
-                    edge_types=edge_types_perm,
+                    edge_types_perm,
+                    self.num_rels,
                 )
         else:
-            _graph = make_fg_csr_hg(
+            _graph = StaticHeteroCSC(
                 offsets,
                 indices,
-                n_node_types=0,
-                n_edge_types=self.num_rels,
-                node_types=None,
-                edge_types=edge_types_perm,
+                edge_types_perm,
+                self.num_rels,
             )
 
         h = RelGraphConvAgg(
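A corresponding sketch for the relational layer (hedged: the regularizer and num_bases arguments follow the dgl.nn.RelGraphConv convention and are assumptions, as they do not appear in this diff):

import dgl
import torch
from dgl.nn import CuGraphRelGraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0])).to("cuda")
feat = torch.randn(3, 10, device="cuda")
etypes = torch.tensor([0, 1, 0], device="cuda")  # one relation id per edge
conv = CuGraphRelGraphConv(
    10, 4, num_rels=2, regularizer="basis", num_bases=2  # assumed kwargs
).to("cuda")
out = conv(g, feat, etypes)  # shape: (3, 4)
# forward() casts etypes to int32 internally (etypes[edge_ids.long()].int()),
# so any integer dtype is accepted.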

python/dgl/nn/pytorch/conv/cugraph_sageconv.py

+18 −23
@@ -2,19 +2,20 @@
 cugraph-ops"""
 # pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
 
-import torch
 from torch import nn
 
+from .cugraph_base import CuGraphBaseConv
+
 try:
-    from pylibcugraphops import make_fg_csr, make_mfg_csr
-    from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
+    from pylibcugraphops.pytorch import SampledCSC, StaticCSC
+    from pylibcugraphops.pytorch.operators import agg_concat_n2n as SAGEConvAgg
+
+    HAS_PYLIBCUGRAPHOPS = True
 except ImportError:
-    has_pylibcugraphops = False
-else:
-    has_pylibcugraphops = True
+    HAS_PYLIBCUGRAPHOPS = False
 
 
-class CuGraphSAGEConv(nn.Module):
+class CuGraphSAGEConv(CuGraphBaseConv):
     r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
     on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the
     highly-optimized aggregation primitives in cugraph-ops:
@@ -27,7 +28,8 @@ class CuGraphSAGEConv(nn.Module):
         (h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)})
 
     This module depends on :code:`pylibcugraphops` package, which can be
-    installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
+    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
+    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
 
     .. note::
         This is an **experimental** feature.
@@ -74,10 +76,11 @@ def __init__(
         feat_drop=0.0,
         bias=True,
    ):
-        if has_pylibcugraphops is False:
+        if HAS_PYLIBCUGRAPHOPS is False:
             raise ModuleNotFoundError(
-                f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
-                f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
+                f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
+                f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. "
+                f"pylibcugraphops requires Python 3.8 or 3.10."
             )
 
         valid_aggr_types = {"max", "min", "mean", "sum"}
@@ -126,25 +129,17 @@ def forward(self, g, feat, max_in_degree=None):
                 max_in_degree = g.in_degrees().max().item()
 
             if max_in_degree < self.MAX_IN_DEGREE_MFG:
-                _graph = make_mfg_csr(
-                    g.dstnodes(),
+                _graph = SampledCSC(
                     offsets,
                     indices,
                     max_in_degree,
                     g.num_src_nodes(),
                 )
             else:
-                offsets_fg = torch.empty(
-                    g.num_src_nodes() + 1,
-                    dtype=offsets.dtype,
-                    device=offsets.device,
-                )
-                offsets_fg[: offsets.numel()] = offsets
-                offsets_fg[offsets.numel() :] = offsets[-1]
-
-                _graph = make_fg_csr(offsets_fg, indices)
+                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
+                _graph = StaticCSC(offsets_fg, indices)
         else:
-            _graph = make_fg_csr(offsets, indices)
+            _graph = StaticCSC(offsets, indices)
 
         feat = self.feat_drop(feat)
         h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()]
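And a sketch for the GraphSAGE layer (the aggregator string must be one of the valid_aggr_types checked in __init__: "max", "min", "mean", "sum"; CUDA device and pylibcugraphops=23.04 assumed as above):

import dgl
import torch
from dgl.nn import CuGraphSAGEConv

g = dgl.graph(([0, 1, 2], [1, 2, 0])).to("cuda")
feat = torch.randn(3, 10, device="cuda")
conv = CuGraphSAGEConv(10, 2, "mean").to("cuda")
out = conv(g, feat)  # shape: (3, 2)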

tests/cugraph/cugraph-ops/test_cugraph_gatconv.py

-1
@@ -24,7 +24,6 @@ def generate_graph():
     return g
 
 
-@pytest.mark.skip()
 @pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
 def test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block):
     device = "cuda:0"

tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py

-1
@@ -27,7 +27,6 @@ def generate_graph():
     return g
 
 
-@pytest.mark.skip()
 @pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
 def test_relgraphconv_equality(
     idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block

tests/cugraph/cugraph-ops/test_cugraph_sageconv.py

-1
@@ -23,7 +23,6 @@ def generate_graph():
     return g
 
 
-@pytest.mark.skip()
 @pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
 def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
     device = "cuda:0"
