Skip to content

Commit cc079e1

Browse files
Skeleton003Ubuntufrozenbugs
authored
[GraphBolt] Add implementation of num_nodes (#6395)
Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Hongzhi (Steve), Chen <[email protected]>
1 parent fdb4737 commit cc079e1

File tree

3 files changed

+131
-18
lines changed

3 files changed

+131
-18
lines changed

python/dgl/graphbolt/impl/csc_sampling_graph.py

+53-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ...convert import to_homogeneous
1313
from ...heterograph import DGLGraph
1414
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
15+
from ..sampling_graph import SamplingGraph
1516
from .sampled_subgraph_impl import SampledSubgraphImpl
1617

1718

@@ -74,7 +75,7 @@ def __init__(
7475
self.edge_type_to_id = edge_type_to_id
7576

7677

77-
class CSCSamplingGraph:
78+
class CSCSamplingGraph(SamplingGraph):
7879
r"""Class for CSC sampling graph."""
7980

8081
def __repr__(self):
@@ -83,6 +84,7 @@ def __repr__(self):
8384
def __init__(
8485
self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata]
8586
):
87+
super().__init__()
8688
self._c_csc_graph = c_csc_graph
8789
self._metadata = metadata
8890

@@ -108,6 +110,54 @@ def total_num_edges(self) -> int:
108110
"""
109111
return self._c_csc_graph.num_edges()
110112

113+
@property
114+
def num_nodes(self) -> Union[int, Dict[str, int]]:
115+
"""The number of nodes in the graph.
116+
- If the graph is homogenous, returns an integer.
117+
- If the graph is heterogenous, returns a dictionary.
118+
119+
Returns
120+
-------
121+
Union[int, Dict[str, int]]
122+
The number of nodes. Integer indicates the total nodes number of a
123+
homogenous graph; dict indicates nodes number per node types of a
124+
heterogenous graph.
125+
126+
Examples
127+
--------
128+
>>> import dgl.graphbolt as gb, torch
129+
>>> total_num_nodes = 5
130+
>>> total_num_edges = 12
131+
>>> ntypes = {"N0": 0, "N1": 1}
132+
>>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
133+
... "N1:R2:N0": 2, "N1:R3:N1": 3}
134+
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
135+
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
136+
>>> node_type_offset = torch.LongTensor([0, 2, 5])
137+
>>> type_per_edge = torch.LongTensor(
138+
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
139+
>>> metadata = gb.GraphMetadata(ntypes, etypes)
140+
>>> graph = gb.from_csc(indptr, indices, node_type_offset,
141+
... type_per_edge, None, metadata)
142+
>>> print(graph.num_nodes)
143+
{'N0': tensor(2), 'N1': tensor(3)}
144+
"""
145+
146+
offset = self.node_type_offset
147+
148+
# Homogenous.
149+
if offset is None or self.metadata is None:
150+
return self._c_csc_graph.num_nodes()
151+
152+
# Heterogenous
153+
else:
154+
num_nodes_per_type = {
155+
_type: offset[_idx + 1] - offset[_idx]
156+
for _type, _idx in self.metadata.node_type_to_id.items()
157+
}
158+
159+
return num_nodes_per_type
160+
111161
@property
112162
def csc_indptr(self) -> torch.tensor:
113163
"""Returns the indices pointer in the CSC graph.
@@ -312,8 +362,8 @@ def sample_neighbors(
312362
without replacement. If True, a value can be selected multiple
313363
times. Otherwise, each value can be selected only once.
314364
probs_name: str, optional
315-
An optional string specifying the name of an edge attribute used a. This
316-
attribute tensor should contain (unnormalized) probabilities
365+
An optional string specifying the name of an edge attribute used.
366+
This attribute tensor should contain (unnormalized) probabilities
317367
corresponding to each neighboring edge of a node. It must be a 1D
318368
floating-point or boolean tensor, with the number of elements
319369
equalling the total number of edges.

python/dgl/graphbolt/sampling_graph.py

-15
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,3 @@ def num_nodes(self) -> Union[int, Dict[str, int]]:
2323
heterogenous graph.
2424
"""
2525
raise NotImplementedError
26-
27-
@property
28-
def num_edges(self) -> Union[int, Dict[str, int]]:
29-
"""The number of edges in the graph.
30-
- If the graph is homogenous, returns an integer.
31-
- If the graph is heterogenous, returns a dictionary.
32-
33-
Returns
34-
-------
35-
Union[int, Dict[str, int]]
36-
The number of edges. Integer indicates the total edges number of a
37-
homogenous graph; dict indicates edges number per edge types of a
38-
heterogenous graph.
39-
"""
40-
raise NotImplementedError

tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py

+78
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,84 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
177177
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id
178178

179179

180+
@unittest.skipIf(
181+
F._default_context_str == "gpu",
182+
reason="Graph is CPU only at present.",
183+
)
184+
@pytest.mark.parametrize(
185+
"total_num_nodes, total_num_edges",
186+
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
187+
)
188+
def test_num_nodes_homo(total_num_nodes, total_num_edges):
189+
csc_indptr, indices = gbt.random_homo_graph(
190+
total_num_nodes, total_num_edges
191+
)
192+
edge_attributes = {
193+
"A1": torch.randn(total_num_edges),
194+
"A2": torch.randn(total_num_edges),
195+
}
196+
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
197+
198+
assert graph.num_nodes == total_num_nodes
199+
200+
201+
@unittest.skipIf(
202+
F._default_context_str == "gpu",
203+
reason="Graph is CPU only at present.",
204+
)
205+
def test_num_nodes_hetero():
206+
"""Original graph in COO:
207+
1 0 1 0 1
208+
1 0 1 1 0
209+
0 1 0 1 0
210+
0 1 0 0 1
211+
1 0 0 0 1
212+
213+
node_type_0: [0, 1]
214+
node_type_1: [2, 3, 4]
215+
edge_type_0: node_type_0 -> node_type_0
216+
edge_type_1: node_type_0 -> node_type_1
217+
edge_type_2: node_type_1 -> node_type_0
218+
edge_type_3: node_type_1 -> node_type_1
219+
"""
220+
# Initialize data.
221+
total_num_nodes = 5
222+
total_num_edges = 12
223+
ntypes = {
224+
"N0": 0,
225+
"N1": 1,
226+
}
227+
etypes = {
228+
"N0:R0:N0": 0,
229+
"N0:R1:N1": 1,
230+
"N1:R2:N0": 2,
231+
"N1:R3:N1": 3,
232+
}
233+
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
234+
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
235+
node_type_offset = torch.LongTensor([0, 2, 5])
236+
type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
237+
assert indptr[-1] == total_num_edges
238+
assert indptr[-1] == len(indices)
239+
assert node_type_offset[-1] == total_num_nodes
240+
assert all(type_per_edge < len(etypes))
241+
242+
# Construct CSCSamplingGraph.
243+
metadata = gb.GraphMetadata(ntypes, etypes)
244+
graph = gb.from_csc(
245+
indptr, indices, node_type_offset, type_per_edge, None, metadata
246+
)
247+
248+
# Verify nodes number per node types.
249+
assert graph.num_nodes == {
250+
"N0": 2,
251+
"N1": 3,
252+
}
253+
assert graph.num_nodes["N0"] == 2
254+
assert graph.num_nodes["N1"] == 3
255+
assert "N2" not in graph.num_nodes
256+
257+
180258
@unittest.skipIf(
181259
F._default_context_str == "gpu",
182260
reason="Graph is CPU only at present.",

0 commit comments

Comments
 (0)