Skip to content

Commit 12841c6

Browse files
authored
[Dist] backward compatible with dgl.dataloading.DistDataLoader (#7782)
1 parent 32b11c9 commit 12841c6

File tree

5 files changed

+80
-1
lines changed

5 files changed

+80
-1
lines changed

docs/source/api/python/dgl.distributed.rst

+4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ Distributed Sampling
6868
Distributed DataLoader
6969
``````````````````````
7070

71+
.. autoclass:: NodeCollator
72+
73+
.. autoclass:: EdgeCollator
74+
7175
.. autoclass:: DistDataLoader
7276

7377
.. autoclass:: DistNodeDataLoader

python/dgl/dataloading/dataloader.py

+42
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,48 @@ def set_epoch(self, epoch):
14741474
raise DGLError("set_epoch is only available when use_ddp is True.")
14751475

14761476

1477+
class NodeCollator:
1478+
"""Deprecated. Please use :class:`~dgl.distributed.NodeCollator` instead."""
1479+
1480+
def __new__(cls, *args, **kwargs):
1481+
dgl_warning(
1482+
"NodeCollator is defined in dgl.distributed This class is for "
1483+
"backward compatibility and will be removed soon. Please update "
1484+
"your code to use `dgl.distributed.NodeCollator`."
1485+
)
1486+
from ..distributed import NodeCollator as NewNodeCollator
1487+
1488+
return NewNodeCollator(*args, **kwargs)
1489+
1490+
1491+
class EdgeCollator:
1492+
"""Deprecated. Please use :class:`~dgl.distributed.EdgeCollator` instead."""
1493+
1494+
def __new__(cls, *args, **kwargs):
1495+
dgl_warning(
1496+
"EdgeCollator is defined in dgl.distributed This class is for "
1497+
"backward compatibility and will be removed soon. Please update "
1498+
"your code to use `dgl.distributed.EdgeCollator`."
1499+
)
1500+
from ..distributed import EdgeCollator as NewEdgeCollator
1501+
1502+
return NewEdgeCollator(*args, **kwargs)
1503+
1504+
1505+
class DistDataLoader:
1506+
"""Deprecated. Please use :class:`~dgl.distributed.DistDataLoader` instead."""
1507+
1508+
def __new__(cls, *args, **kwargs):
1509+
dgl_warning(
1510+
"DistDataLoader is defined in dgl.distributed This class is for "
1511+
"backward compatibility and will be removed soon. Please update "
1512+
"your code to use `dgl.distributed.DistDataLoader`."
1513+
)
1514+
from ..distributed import DistDataLoader as NewDistDataLoader
1515+
1516+
return NewDistDataLoader(*args, **kwargs)
1517+
1518+
14771519
class DistNodeDataLoader:
14781520
"""Deprecated. Please use :class:`~dgl.distributed.DistNodeDataLoader`
14791521
instead.

python/dgl/distributed/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
DistDataLoader,
77
DistEdgeDataLoader,
88
DistNodeDataLoader,
9+
EdgeCollator,
10+
NodeCollator,
911
)
1012
from .dist_graph import DistGraph, DistGraphServer, edge_split, node_split
1113
from .dist_tensor import DistTensor

python/dgl/distributed/dist_dataloader.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
from ..convert import heterograph
1010
from .dist_context import get_sampler_pool
1111

12-
__all__ = ["DistDataLoader", "DistNodeDataLoader", "DistEdgeDataLoader"]
12+
__all__ = [
13+
"NodeCollator",
14+
"EdgeCollator",
15+
"DistDataLoader",
16+
"DistNodeDataLoader",
17+
"DistEdgeDataLoader",
18+
]
1319

1420
DATALOADER_ID = 0
1521

tests/python/pytorch/dataloading/test_dataloader.py

+25
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,31 @@ def test_dataloader_worker_init_fn():
821821
pass
822822

823823

824+
def test_distributed_dataloaders():
825+
# Test distributed dataloaders could be successfully imported.
826+
try:
827+
from dgl.dataloading import (
828+
DistDataLoader,
829+
DistEdgeDataLoader,
830+
DistNodeDataLoader,
831+
EdgeCollator,
832+
NodeCollator,
833+
)
834+
except ImportError:
835+
pytest.fail("Distributed DataLoader from dataloading import failed")
836+
837+
try:
838+
from dgl.distributed import (
839+
DistDataLoader,
840+
DistEdgeDataLoader,
841+
DistNodeDataLoader,
842+
EdgeCollator,
843+
NodeCollator,
844+
)
845+
except ImportError:
846+
pytest.fail("Distributed DataLoader from dataloading import failed")
847+
848+
824849
if __name__ == "__main__":
825850
# test_node_dataloader(F.int32, 'neighbor', None)
826851
test_edge_dataloader_excludes(

0 commit comments

Comments
 (0)