Skip to content

Commit 7ec78bb

Browse files
frozenbugsUbuntu
and
Ubuntu
authored
[Graphbolt] change dataset method to property. (#6023)
Co-authored-by: Ubuntu <[email protected]>
1 parent 4135b1b commit 7ec78bb

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

python/dgl/graphbolt/dataset.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,27 @@ class Dataset:
3131
generate a subgraph.
3232
"""
3333

34+
@property
3435
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
3536
"""Return the training sets."""
3637
raise NotImplementedError
3738

39+
@property
3840
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
3941
"""Return the validation sets."""
4042
raise NotImplementedError
4143

44+
@property
4245
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
4346
"""Return the test sets."""
4447
raise NotImplementedError
4548

49+
@property
4650
def graph(self) -> object:
4751
"""Return the graph."""
4852
raise NotImplementedError
4953

54+
@property
5055
def feature(self) -> Dict[object, FeatureStore]:
5156
"""Return the feature."""
5257
raise NotImplementedError

python/dgl/graphbolt/impl/ondisk_dataset.py

+5
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,27 @@ def __init__(self, path: str) -> None:
281281
self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
282282
self._test_sets = self._init_tvt_sets(self._meta.test_sets)
283283

284+
@property
284285
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
285286
"""Return the training set."""
286287
return self._train_sets
287288

289+
@property
288290
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
289291
"""Return the validation set."""
290292
return self._validation_sets
291293

294+
@property
292295
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
293296
"""Return the test set."""
294297
return self._test_sets
295298

299+
@property
296300
def graph(self) -> object:
297301
"""Return the graph."""
298302
return self._graph
299303

304+
@property
300305
def feature(self) -> Dict[Tuple, TorchBasedFeatureStore]:
301306
"""Return the feature."""
302307
return self._feature

tests/python/pytorch/graphbolt/test_ondisk_dataset.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
106106
dataset = gb.OnDiskDataset(yaml_file)
107107

108108
# Verify train set.
109-
train_sets = dataset.train_sets()
109+
train_sets = dataset.train_sets
110110
assert len(train_sets) == 2
111111
for train_set in train_sets:
112112
assert len(train_set) == 1000
@@ -117,7 +117,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
117117
train_sets = None
118118

119119
# Verify validation set.
120-
validation_sets = dataset.validation_sets()
120+
validation_sets = dataset.validation_sets
121121
assert len(validation_sets) == 2
122122
for validation_set in validation_sets:
123123
assert len(validation_set) == 1000
@@ -128,7 +128,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
128128
validation_sets = None
129129

130130
# Verify test set.
131-
test_sets = dataset.test_sets()
131+
test_sets = dataset.test_sets
132132
assert len(test_sets) == 2
133133
for test_set in test_sets:
134134
assert len(test_set) == 1000
@@ -151,9 +151,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
151151
f.write(yaml_content)
152152

153153
dataset = gb.OnDiskDataset(yaml_file)
154-
assert dataset.train_sets() is not None
155-
assert dataset.validation_sets() is None
156-
assert dataset.test_sets() is None
154+
assert dataset.train_sets is not None
155+
assert dataset.validation_sets is None
156+
assert dataset.test_sets is None
157157
dataset = None
158158

159159

@@ -209,7 +209,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
209209
dataset = gb.OnDiskDataset(yaml_file)
210210

211211
# Verify train set.
212-
train_sets = dataset.train_sets()
212+
train_sets = dataset.train_sets
213213
assert len(train_sets) == 2
214214
for train_set in train_sets:
215215
assert len(train_set) == 1000
@@ -221,7 +221,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
221221
train_sets = None
222222

223223
# Verify validation set.
224-
validation_sets = dataset.validation_sets()
224+
validation_sets = dataset.validation_sets
225225
assert len(validation_sets) == 2
226226
for validation_set in validation_sets:
227227
assert len(validation_set) == 1000
@@ -233,7 +233,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
233233
validation_sets = None
234234

235235
# Verify test set.
236-
test_sets = dataset.test_sets()
236+
test_sets = dataset.test_sets
237237
assert len(test_sets) == 2
238238
for test_set in test_sets:
239239
assert len(test_set) == 1000
@@ -299,7 +299,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
299299
dataset = gb.OnDiskDataset(yaml_file)
300300

301301
# Verify train set.
302-
train_sets = dataset.train_sets()
302+
train_sets = dataset.train_sets
303303
assert len(train_sets) == 2
304304
for train_set in train_sets:
305305
assert len(train_set) == 1000
@@ -315,7 +315,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
315315
train_sets = None
316316

317317
# Verify validation set.
318-
validation_sets = dataset.validation_sets()
318+
validation_sets = dataset.validation_sets
319319
assert len(validation_sets) == 2
320320
for validation_set in validation_sets:
321321
assert len(validation_set) == 1000
@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
331331
validation_sets = None
332332

333333
# Verify test set.
334-
test_sets = dataset.test_sets()
334+
test_sets = dataset.test_sets
335335
assert len(test_sets) == 2
336336
for test_set in test_sets:
337337
assert len(test_set) == 1000
@@ -401,7 +401,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
401401
dataset = gb.OnDiskDataset(yaml_file)
402402

403403
# Verify train set.
404-
train_sets = dataset.train_sets()
404+
train_sets = dataset.train_sets
405405
assert len(train_sets) == 2
406406
for train_set in train_sets:
407407
assert len(train_set) == 1000
@@ -418,7 +418,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
418418
train_sets = None
419419

420420
# Verify validation set.
421-
validation_sets = dataset.validation_sets()
421+
validation_sets = dataset.validation_sets
422422
assert len(validation_sets) == 2
423423
for validation_set in validation_sets:
424424
assert len(validation_set) == 1000
@@ -435,7 +435,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
435435
validation_sets = None
436436

437437
# Verify test set.
438-
test_sets = dataset.test_sets()
438+
test_sets = dataset.test_sets
439439
assert len(test_sets) == 2
440440
for test_set in test_sets:
441441
assert len(test_set) == 1000
@@ -507,7 +507,7 @@ def test_OnDiskDataset_Feature_heterograph():
507507
dataset = gb.OnDiskDataset(yaml_file)
508508

509509
# Verify feature data storage.
510-
feature_data = dataset.feature()
510+
feature_data = dataset.feature
511511
assert len(feature_data) == 4
512512

513513
# Verify node feature data.
@@ -595,7 +595,7 @@ def test_OnDiskDataset_Feature_homograph():
595595
dataset = gb.OnDiskDataset(yaml_file)
596596

597597
# Verify feature data storage.
598-
feature_data = dataset.feature()
598+
feature_data = dataset.feature
599599
assert len(feature_data) == 4
600600

601601
# Verify node feature data.
@@ -661,7 +661,7 @@ def test_OnDiskDataset_Graph_homogeneous():
661661
f.write(yaml_content)
662662

663663
dataset = gb.OnDiskDataset(yaml_file)
664-
graph2 = dataset.graph()
664+
graph2 = dataset.graph
665665

666666
assert graph.num_nodes == graph2.num_nodes
667667
assert graph.num_edges == graph2.num_edges
@@ -703,7 +703,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
703703
f.write(yaml_content)
704704

705705
dataset = gb.OnDiskDataset(yaml_file)
706-
graph2 = dataset.graph()
706+
graph2 = dataset.graph
707707

708708
assert graph.num_nodes == graph2.num_nodes
709709
assert graph.num_edges == graph2.num_edges

0 commit comments

Comments
 (0)