12
12
from ...convert import to_homogeneous
13
13
from ...heterograph import DGLGraph
14
14
from ..base import etype_str_to_tuple , etype_tuple_to_str , ORIGINAL_EDGE_ID
15
+ from ..sampling_graph import SamplingGraph
15
16
from .sampled_subgraph_impl import SampledSubgraphImpl
16
17
17
18
@@ -74,7 +75,7 @@ def __init__(
74
75
self .edge_type_to_id = edge_type_to_id
75
76
76
77
77
- class CSCSamplingGraph :
78
+ class CSCSamplingGraph ( SamplingGraph ) :
78
79
r"""Class for CSC sampling graph."""
79
80
80
81
def __repr__ (self ):
@@ -83,6 +84,7 @@ def __repr__(self):
83
84
def __init__ (
84
85
self , c_csc_graph : torch .ScriptObject , metadata : Optional [GraphMetadata ]
85
86
):
87
+ super ().__init__ ()
86
88
self ._c_csc_graph = c_csc_graph
87
89
self ._metadata = metadata
88
90
@@ -108,6 +110,54 @@ def total_num_edges(self) -> int:
108
110
"""
109
111
return self ._c_csc_graph .num_edges ()
110
112
113
+ @property
114
+ def num_nodes (self ) -> Union [int , Dict [str , int ]]:
115
+ """The number of nodes in the graph.
116
+ - If the graph is homogenous, returns an integer.
117
+ - If the graph is heterogenous, returns a dictionary.
118
+
119
+ Returns
120
+ -------
121
+ Union[int, Dict[str, int]]
122
+ The number of nodes. Integer indicates the total nodes number of a
123
+ homogenous graph; dict indicates nodes number per node types of a
124
+ heterogenous graph.
125
+
126
+ Examples
127
+ --------
128
+ >>> import dgl.graphbolt as gb, torch
129
+ >>> total_num_nodes = 5
130
+ >>> total_num_edges = 12
131
+ >>> ntypes = {"N0": 0, "N1": 1}
132
+ >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
133
+ ... "N1:R2:N0": 2, "N1:R3:N1": 3}
134
+ >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
135
+ >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
136
+ >>> node_type_offset = torch.LongTensor([0, 2, 5])
137
+ >>> type_per_edge = torch.LongTensor(
138
+ ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
139
+ >>> metadata = gb.GraphMetadata(ntypes, etypes)
140
+ >>> graph = gb.from_csc(indptr, indices, node_type_offset,
141
+ ... type_per_edge, None, metadata)
142
+ >>> print(graph.num_nodes)
143
+ {'N0': tensor(2), 'N1': tensor(3)}
144
+ """
145
+
146
+ offset = self .node_type_offset
147
+
148
+ # Homogenous.
149
+ if offset is None or self .metadata is None :
150
+ return self ._c_csc_graph .num_nodes ()
151
+
152
+ # Heterogenous
153
+ else :
154
+ num_nodes_per_type = {
155
+ _type : offset [_idx + 1 ] - offset [_idx ]
156
+ for _type , _idx in self .metadata .node_type_to_id .items ()
157
+ }
158
+
159
+ return num_nodes_per_type
160
+
111
161
@property
112
162
def csc_indptr (self ) -> torch .tensor :
113
163
"""Returns the indices pointer in the CSC graph.
@@ -312,8 +362,8 @@ def sample_neighbors(
312
362
without replacement. If True, a value can be selected multiple
313
363
times. Otherwise, each value can be selected only once.
314
364
probs_name: str, optional
315
- An optional string specifying the name of an edge attribute used a. This
316
- attribute tensor should contain (unnormalized) probabilities
365
+ An optional string specifying the name of an edge attribute used.
366
+ This attribute tensor should contain (unnormalized) probabilities
317
367
corresponding to each neighboring edge of a node. It must be a 1D
318
368
floating-point or boolean tensor, with the number of elements
319
369
equalling the total number of edges.
0 commit comments