|
9 | 9 | from dgl.utils import recursive_apply
|
10 | 10 |
|
11 | 11 | from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
|
12 |
| -from .internal import get_attributes |
| 12 | +from .internal import get_attributes, get_nonproperty_attributes |
13 | 13 | from .sampled_subgraph import SampledSubgraph
|
14 | 14 |
|
15 | 15 | __all__ = ["MiniBatch"]
|
@@ -556,23 +556,14 @@ def to_pyg_data(self):
|
556 | 556 | def to(self, device: torch.device): # pylint: disable=invalid-name
|
557 | 557 | """Copy `MiniBatch` to the specified device using reflection."""
|
558 | 558 |
|
559 |
| - def _to(x, device): |
| 559 | + def _to(x): |
560 | 560 | return x.to(device) if hasattr(x, "to") else x
|
561 | 561 |
|
562 |
| - def apply_to(x, device): |
563 |
| - return recursive_apply(x, lambda x: _to(x, device)) |
564 |
| - |
565 |
| - transfer_attrs = get_attributes(self) |
| 562 | + transfer_attrs = get_nonproperty_attributes(self) |
566 | 563 |
|
567 | 564 | for attr in transfer_attrs:
|
568 | 565 | # Only copy member variables.
|
569 |
| - try: |
570 |
| - # For read-only attributes such as blocks , setattr will throw |
571 |
| - # an AttributeError. We catch these exceptions and skip those |
572 |
| - # attributes. |
573 |
| - setattr(self, attr, apply_to(getattr(self, attr), device)) |
574 |
| - except AttributeError: |
575 |
| - continue |
| 566 | + setattr(self, attr, recursive_apply(getattr(self, attr), _to)) |
576 | 567 |
|
577 | 568 | return self
|
578 | 569 |
|
|
0 commit comments