Skip to content

Commit 6f2ccbf

Browse files
authored
[GraphBolt][CUDA] Make dataloader pickleable. (#7391)
1 parent 7a2e873 commit 6f2ccbf

File tree

4 files changed: 29 additions (+29) and 12 deletions (-12).

examples/sampling/graphbolt/link_prediction.py

-2
Original file line number | Diff line number | Diff line change
@@ -328,8 +328,6 @@ def train(args, model, graph, features, train_set):
328328

329329
total_loss += loss.item()
330330
if step + 1 == args.early_stop:
331-
# Early stopping requires a new dataloader to reset its state.
332-
dataloader = create_dataloader(args, graph, features, train_set)
333331
break
334332

335333
end_epoch_time = time.time()

notebooks/stochastic_training/link_prediction.ipynb

+9 -10
Original file line number | Diff line number | Diff line change
@@ -129,14 +129,13 @@
129129
"outputs": [],
130130
"source": [
131131
"from functools import partial\n",
132-
"def create_train_dataloader():\n",
133-
" datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n",
134-
" datapipe = datapipe.copy_to(device)\n",
135-
" datapipe = datapipe.sample_uniform_negative(graph, 5)\n",
136-
" datapipe = datapipe.sample_neighbor(graph, [5, 5])\n",
137-
" datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n",
138-
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
139-
" return gb.DataLoader(datapipe)"
132+
"datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n",
133+
"datapipe = datapipe.copy_to(device)\n",
134+
"datapipe = datapipe.sample_uniform_negative(graph, 5)\n",
135+
"datapipe = datapipe.sample_neighbor(graph, [5, 5])\n",
136+
"datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n",
137+
"datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
138+
"train_dataloader = gb.DataLoader(datapipe)"
140139
]
141140
},
142141
{
@@ -157,7 +156,7 @@
157156
},
158157
"outputs": [],
159158
"source": [
160-
"data = next(iter(create_train_dataloader()))\n",
159+
"data = next(iter(train_dataloader))\n",
161160
"print(f\"MiniBatch: {data}\")"
162161
]
163162
},
@@ -253,7 +252,7 @@
253252
"for epoch in range(3):\n",
254253
" model.train()\n",
255254
" total_loss = 0\n",
256-
" for step, data in tqdm(enumerate(create_train_dataloader())):\n",
255+
" for step, data in tqdm(enumerate(train_dataloader)):\n",
257256
" # Get node pairs with labels for loss calculation.\n",
258257
" compacted_seeds = data.compacted_seeds.T\n",
259258
" labels = data.labels\n",

python/dgl/graphbolt/base.py

+14
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,20 @@ def __iter__(self):
313313
while len(self.buffer) > 0:
314314
yield self.buffer.popleft()
315315

316+
def __getstate__(self):
317+
state = (self.datapipe, self.buffer.maxlen)
318+
if IterDataPipe.getstate_hook is not None:
319+
return IterDataPipe.getstate_hook(state)
320+
return state
321+
322+
def __setstate__(self, state):
323+
self.datapipe, buffer_size = state
324+
self.buffer = deque(maxlen=buffer_size)
325+
326+
def reset(self):
327+
"""Resets the state of the datapipe."""
328+
self.buffer.clear()
329+
316330

317331
@functional_datapipe("wait")
318332
class Waiter(IterDataPipe):

tests/python/pytorch/graphbolt/test_dataloader.py

+6
Original file line number | Diff line number | Diff line change
@@ -111,3 +111,9 @@ def test_gpu_sampling_DataLoader(
111111
)
112112
assert len(bufferers) == bufferer_awaiter_cnt
113113
assert len(list(dataloader)) == N // B
114+
115+
for i, _ in enumerate(dataloader):
116+
if i >= 1:
117+
break
118+
119+
assert len(list(dataloader)) == N // B

0 commit comments

Comments
 (0)