Skip to content

Commit 06074d7

Browse files
authored
[GraphBolt] enrich node types for input/output nodes of sampled subgraph (#6348)
1 parent adf4993 commit 06074d7

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

python/dgl/graphbolt/impl/csc_sampling_graph.py

-2
Original file line number | Diff line number | Diff line change
@@ -242,8 +242,6 @@ def _convert_to_sampled_subgraph(
242 242
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
243 243
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
244 244
mask = type_per_edge == etype_id
245-
if mask.count_nonzero() == 0:
246-
continue
247 245
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
248 246
hetero_column = (
249 247
column[mask] - self.node_type_offset[dst_ntype_id]

python/dgl/graphbolt/impl/neighbor_sampler.py

+8
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ def __init__(
78 78
3
79 79
"""
80 80
super().__init__(datapipe)
81+
self.graph = graph
81 82
# Convert fanouts to a list of tensors.
82 83
self.fanouts = []
83 84
for fanout in fanouts:
@@ -91,6 +92,13 @@ def __init__(
91 92
def _sample_subgraphs(self, seeds):
92 93
subgraphs = []
93 94
num_layers = len(self.fanouts)
95+
# Enrich seeds with all node types.
96+
if isinstance(seeds, dict):
97+
ntypes = list(self.graph.metadata.node_type_to_id.keys())
98+
seeds = {
99+
ntype: seeds.get(ntype, torch.LongTensor([]))
100+
for ntype in ntypes
101+
}
94 102
for hop in range(num_layers):
95 103
subgraph = self.sampler(
96 104
seeds,

tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor):
584 584
torch.LongTensor([0, 2]),
585 585
torch.LongTensor([0, 0]),
586 586
),
587+
"n1:e1:n2": (
588+
torch.LongTensor([]),
589+
torch.LongTensor([]),
590+
),
587 591
}
588-
assert len(subgraph.node_pairs) == 1
592+
assert len(subgraph.node_pairs) == 2
589 593
for etype, pairs in expected_node_pairs.items():
590 594
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
591 595
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])

tests/python/pytorch/graphbolt/test_subgraph_sampler.py

+17
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,23 @@ def get_hetero_graph():
129 129
)
130 130

131 131

132+
@pytest.mark.parametrize("labor", [False, True])
133+
def test_SubgraphSampler_Node_Hetero(labor):
134+
graph = get_hetero_graph()
135+
itemset = gb.ItemSetDict(
136+
{"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
137+
)
138+
item_sampler = gb.ItemSampler(itemset, batch_size=2)
139+
num_layer = 2
140+
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
141+
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
142+
sampler_dp = Sampler(item_sampler, graph, fanouts)
143+
assert len(list(sampler_dp)) == 2
144+
for minibatch in sampler_dp:
145+
blocks = minibatch.to_dgl_blocks()
146+
assert len(blocks) == num_layer
147+
148+
132 149
@pytest.mark.parametrize("labor", [False, True])
133 150
def test_SubgraphSampler_Link_Hetero(labor):
134 151
graph = get_hetero_graph()

0 commit comments

Comments
 (0)