Skip to content

Commit 4135b1b

Browse files
[Performance] Fused sampling with compaction (#5924)
Co-authored-by: Hesham Mostafa <[email protected]>
1 parent 4ceb0bf commit 4135b1b

File tree

15 files changed

+1280
-82
lines changed

15 files changed

+1280
-82
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import time
2+
3+
import dgl
4+
import dgl.function as fn
5+
6+
import numpy as np
7+
import torch
8+
9+
from .. import utils
10+
11+
12+
@utils.benchmark("time")
13+
@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"])
14+
@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"])
15+
@utils.parametrize("format", ["csr", "csc"])
16+
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
17+
@utils.parametrize("fanout", [5, 20, 40])
18+
def track_time(graph_name, format, seed_nodes_num, fanout):
19+
device = utils.get_bench_device()
20+
graph = utils.get_graph(graph_name, format).to(device)
21+
22+
edge_dir = "in" if format == "csc" else "out"
23+
seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
24+
seed_nodes = torch.from_numpy(seed_nodes).to(device)
25+
26+
# dry run
27+
for i in range(3):
28+
dgl.sampling.sample_neighbors_fused(
29+
graph, seed_nodes, fanout, edge_dir=edge_dir
30+
)
31+
32+
# timing
33+
with utils.Timer() as t:
34+
for i in range(50):
35+
dgl.sampling.sample_neighbors_fused(
36+
graph, seed_nodes, fanout, edge_dir=edge_dir
37+
)
38+
39+
return t.elapsed_secs / 50

include/dgl/aten/csr.h

+66
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling(
572572
CSRMatrix mat, IdArray rows, int64_t num_samples,
573573
NDArray prob_or_mask = NDArray(), bool replace = true);
574574

575+
/*!
576+
* @brief Randomly select a fixed number of non-zero entries along each given
577+
* row independently.
578+
*
579+
* The function performs random choices along each row independently.
580+
* The picked indices are returned in the form of a CSR matrix, with
581+
* additional IdArray that is an extended version of CSR's index pointers.
582+
*
583+
* When the template parameter map_seed_nodes is true, the rows are also
584+
* copied into new_seed_nodes and their positions recorded in seed_mapping.
585+
*
586+
* If replace is false and a row has fewer non-zero values than num_samples,
587+
* all the values are picked.
588+
*
589+
* Examples:
590+
*
591+
* // csr.num_rows = 4;
592+
* // csr.num_cols = 4;
593+
* // csr.indptr = [0, 2, 3, 3, 5]
594+
* // csr.indices = [0, 1, 1, 2, 3]
595+
* // csr.data = [2, 3, 0, 1, 4]
596+
* CSRMatrix csr = ...;
597+
* IdArray rows = ... ; // [1, 3]
598+
* IdArray seed_mapping = [-1, -1, -1, -1];
599+
* std::vector<IdType> new_seed_nodes = {};
600+
*
601+
* std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<
602+
* typename IdType, True>(
603+
* csr, rows, seed_mapping,
604+
* new_seed_nodes, 2,
605+
* FloatArray(), false);
606+
* // possible sampled csr matrix:
607+
* // sampled.first.num_rows = 2
608+
* // sampled.first.num_cols = 3
609+
* // sampled.first.indptr = [0, 1, 3]
610+
* // sampled.first.indices = [1, 2, 3]
611+
* // sampled.first.data = [0, 1, 4]
612+
* // sampled.second = [0, 1, 1]
613+
* // seed_mapping = [-1, 0, -1, 1];
614+
* // new_seed_nodes = {1, 3};
615+
*
616+
* @tparam IdType Graph's index data type, can be int32_t or int64_t
617+
* @tparam map_seed_nodes If true, rows are mapped and copied to new_seed_nodes
618+
* @param mat Input CSR matrix.
619+
* @param rows Rows to sample from.
620+
* @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row
621+
* from rows will be set to its position e.g. mapping[rows[i]] = i.
622+
* @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will
623+
* contain rows.
624+
* @param rows Rows to sample from.
625+
* @param num_samples Number of samples
626+
* @param prob_or_mask Unnormalized probability array or mask array.
627+
* Should be of the same length as the data array.
628+
* If an empty array is provided, assume uniform.
629+
* @param replace True if sample with replacement
630+
* @return A pair of the CSRMatrix storing the picked row, col and data indices,
631+
* and an IdArray holding the COO-style row index of each picked entry
632+
* @note The edges of the entire graph must be ordered by their edge types,
633+
* rows must be unique
634+
*/
635+
template <typename IdType, bool map_seed_nodes>
636+
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
637+
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
638+
std::vector<IdType>* new_seed_nodes, int64_t num_samples,
639+
NDArray prob_or_mask = NDArray(), bool replace = true);
640+
575641
/**
576642
* @brief Randomly select a fixed number of non-zero entries for each edge type
577643
* along each given row independently.

include/dgl/sampling/neighbor.h

+50
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <dgl/array.h>
1010
#include <dgl/base_heterograph.h>
1111

12+
#include <tuple>
1213
#include <vector>
1314

1415
namespace dgl {
@@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors(
4748
const std::vector<FloatArray>& probability,
4849
const std::vector<IdArray>& exclude_edges, bool replace = true);
4950

51+
/**
52+
* @brief Sample from the neighbors of the given nodes and convert a graph into
53+
* a bipartite-structured graph for message passing.
54+
*
55+
* Specifically, we create one node type \c ntype_l on the "left" side and
56+
* another node type \c ntype_r on the "right" side for each node type \c ntype.
57+
* The nodes of type \c ntype_r would contain the nodes designated by the
58+
* caller, and node type \c ntype_l would contain the nodes that have an edge
59+
* connecting to one of the designated nodes.
60+
*
61+
* The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
62+
* When sampling with replacement, the sampled subgraph could have parallel
63+
* edges.
64+
*
65+
* For sampling without replacement, if fanout > the number of neighbors, all the
66+
* neighbors will be sampled.
67+
*
68+
* Non-deterministic algorithm, requires nodes parameter to store unique Node
69+
* IDs.
70+
*
71+
* @tparam IdType Graph's index data type, can be int32_t or int64_t
72+
* @param hg The input graph.
73+
* @param nodes Node IDs of each type. The vector length must be equal to the
74+
* number of node types. Empty array is allowed.
75+
* @param mapping External parameter that should be set to a vector of IdArrays
76+
* filled with -1, required for mapping of nodes in returned
77+
* graph
78+
* @param fanouts Number of sampled neighbors for each edge type. The vector
79+
* length should be equal to the number of edge types, or one if they all have
80+
* the same fanout.
81+
* @param dir Edge direction.
82+
* @param prob_or_mask A vector of 1D float arrays, indicating the transition
83+
* probability of each edge by edge type. An empty float array assumes uniform
84+
* transition.
85+
* @param exclude_edges Edge IDs of each type which will be excluded during
86+
* sampling. The vector length must be equal to the number of edge types. Empty
87+
* array is allowed.
88+
* @param replace If true, sample with replacement.
89+
* @return Sampled neighborhoods as a graph. The return graph has the same
90+
* schema as the original one.
91+
*/
92+
template <typename IdType>
93+
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
94+
SampleNeighborsFused(
95+
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
96+
const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
97+
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
98+
const std::vector<IdArray>& exclude_edges, bool replace = true);
99+
50100
/**
51101
* Select the neighbors with k-largest weights on the connecting edges for each
52102
* given node.

python/dgl/dataloading/neighbor_sampler.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Data loading components for neighbor sampling"""
2+
from .. import backend as F
23
from ..base import EID, NID
4+
from ..heterograph import DGLGraph
35
from ..transforms import to_block
46
from .base import BlockSampler
57

@@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler):
5456
output_device : device, optional
5557
The device of the output subgraphs or MFGs. Default is the same as the
5658
minibatch of seed nodes.
59+
fused : bool, default True
60+
If True and the device is CPU, fused neighbor sampling is invoked. This version
61+
requires seed_nodes to be unique
5762
5863
Examples
5964
--------
@@ -120,6 +125,7 @@ def __init__(
120125
prefetch_labels=None,
121126
prefetch_edge_feats=None,
122127
output_device=None,
128+
fused=True,
123129
):
124130
super().__init__(
125131
prefetch_node_feats=prefetch_node_feats,
@@ -137,10 +143,43 @@ def __init__(
137143
)
138144
self.prob = prob or mask
139145
self.replace = replace
146+
self.fused = fused
147+
self.mapping = {}
148+
self.g = None
140149

141150
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
142151
output_nodes = seed_nodes
143152
blocks = []
153+
154+
if self.fused:
155+
cpu = F.device_type(g.device) == "cpu"
156+
if isinstance(seed_nodes, dict):
157+
for ntype in list(seed_nodes.keys()):
158+
if not cpu:
159+
break
160+
cpu = (
161+
cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
162+
)
163+
else:
164+
cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
165+
if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
166+
if self.g != g:
167+
self.mapping = {}
168+
self.g = g
169+
for fanout in reversed(self.fanouts):
170+
block = g.sample_neighbors_fused(
171+
seed_nodes,
172+
fanout,
173+
edge_dir=self.edge_dir,
174+
prob=self.prob,
175+
replace=self.replace,
176+
exclude_edges=exclude_eids,
177+
mapping=self.mapping,
178+
)
179+
seed_nodes = block.srcdata[NID]
180+
blocks.insert(0, block)
181+
return seed_nodes, output_nodes, blocks
182+
144183
for fanout in reversed(self.fanouts):
145184
frontier = g.sample_neighbors(
146185
seed_nodes,

0 commit comments

Comments
 (0)