Skip to content

Commit 324fd97

Browse files
authored
[GraphBolt][CUDA] Minibatch.to() patch. (#7330)
1 parent 7de2e51 commit 324fd97

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

python/dgl/graphbolt/internal/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,21 @@ def copy_or_convert_data(
129129
save_data(data, output_path, output_format)
130130

131131

132+
def get_nonproperty_attributes(_obj) -> list:
133+
"""Get attributes of the class except for the properties."""
134+
attributes = [
135+
attribute
136+
for attribute in dir(_obj)
137+
if not attribute.startswith("__")
138+
and (
139+
not hasattr(type(_obj), attribute)
140+
or not isinstance(getattr(type(_obj), attribute), property)
141+
)
142+
and not callable(getattr(_obj, attribute))
143+
]
144+
return attributes
145+
146+
132147
def get_attributes(_obj) -> list:
133148
"""Get attributes of the class."""
134149
attributes = [

python/dgl/graphbolt/minibatch.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dgl.utils import recursive_apply
1010

1111
from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
12-
from .internal import get_attributes
12+
from .internal import get_attributes, get_nonproperty_attributes
1313
from .sampled_subgraph import SampledSubgraph
1414

1515
__all__ = ["MiniBatch"]
@@ -556,23 +556,14 @@ def to_pyg_data(self):
556556
def to(self, device: torch.device): # pylint: disable=invalid-name
557557
"""Copy `MiniBatch` to the specified device using reflection."""
558558

559-
def _to(x, device):
559+
def _to(x):
560560
return x.to(device) if hasattr(x, "to") else x
561561

562-
def apply_to(x, device):
563-
return recursive_apply(x, lambda x: _to(x, device))
564-
565-
transfer_attrs = get_attributes(self)
562+
transfer_attrs = get_nonproperty_attributes(self)
566563

567564
for attr in transfer_attrs:
568565
# Only copy member variables.
569-
try:
570-
# For read-only attributes such as blocks , setattr will throw
571-
# an AttributeError. We catch these exceptions and skip those
572-
# attributes.
573-
setattr(self, attr, apply_to(getattr(self, attr), device))
574-
except AttributeError:
575-
continue
566+
setattr(self, attr, recursive_apply(getattr(self, attr), _to))
576567

577568
return self
578569

0 commit comments

Comments
 (0)