@@ -195,8 +195,7 @@ def apply_to(x, device):
195
195
class CopyTo (IterDataPipe ):
196
196
"""DataPipe that transfers each element yielded from the previous DataPipe
197
197
to the given device. For MiniBatch, only the related attributes
198
- (automatically inferred) will be transferred by default. If you want to
199
- transfer any other attributes, indicate them in the ``extra_attrs``.
198
+ (automatically inferred) will be transferred by default.
200
199
201
200
Functional name: :obj:`copy_to`.
202
201
@@ -208,64 +207,22 @@ class CopyTo(IterDataPipe):
208
207
for data in datapipe:
209
208
yield data.to(device)
210
209
211
- For :class:`~dgl.graphbolt.MiniBatch`, only a part of attributes will be
212
- transferred to accelerate the process by default:
213
-
214
- - When ``seed_nodes`` is not None and ``node_pairs`` is None, node related
215
- task is inferred. Only ``labels``, ``sampled_subgraphs``, ``node_features``
216
- and ``edge_features`` will be transferred.
217
-
218
- - When ``node_pairs`` is not None and ``seed_nodes`` is None, edge/link
219
- related task is inferred. Only ``labels``, ``compacted_node_pairs``,
220
- ``compacted_negative_srcs``, ``compacted_negative_dsts``,
221
- ``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be
222
- transferred.
223
-
224
- - When ``seeds`` is not None, only ``labels``, ``compacted_seeds``,
225
- ``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be
226
- transferred.
227
-
228
- - Otherwise, all attributes will be transferred.
229
-
230
- - If you want some other attributes to be transferred as well, please
231
- specify the name in the ``extra_attrs``. For instance, the following code
232
- will copy ``seed_nodes`` to the GPU as well:
233
-
234
- .. code:: python
235
-
236
- datapipe = datapipe.copy_to(device="cuda", extra_attrs=["seed_nodes"])
237
-
238
210
Parameters
239
211
----------
240
212
datapipe : DataPipe
241
213
The DataPipe.
242
214
device : torch.device
243
215
The PyTorch CUDA device.
244
- extra_attrs: List[string]
245
- The extra attributes of the data in the DataPipe you want to be carried
246
- to the specific device. The attributes specified in the ``extra_attrs``
247
- will be transferred regardless of the task inferred. It could also be
248
- applied to classes other than :class:`~dgl.graphbolt.MiniBatch`.
249
216
"""
250
217
251
- def __init__ (self , datapipe , device , extra_attrs = None ):
218
+ def __init__ (self , datapipe , device ):
252
219
super ().__init__ ()
253
220
self .datapipe = datapipe
254
221
self .device = device
255
- self .extra_attrs = extra_attrs
256
222
257
223
def __iter__ (self ):
258
224
for data in self .datapipe :
259
225
data = recursive_apply (data , apply_to , self .device )
260
- if self .extra_attrs is not None :
261
- for attr in self .extra_attrs :
262
- setattr (
263
- data ,
264
- attr ,
265
- recursive_apply (
266
- getattr (data , attr ), apply_to , self .device
267
- ),
268
- )
269
226
yield data
270
227
271
228
0 commit comments