Skip to content

Commit 9116f67

Browse files
authored
[Graphbolt]Unique and compact OP (#6098)
1 parent 88964a8 commit 9116f67

File tree

5 files changed

+172
-60
lines changed

5 files changed

+172
-60
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/**
2+
* Copyright (c) 2023 by Contributors
3+
*
4+
* @file unique_and_compact.h
5+
* @brief Unique and compact op.
6+
*/
7+
#ifndef GRAPHBOLT_UNIQUE_AND_COMPACT_H_
8+
#define GRAPHBOLT_UNIQUE_AND_COMPACT_H_
9+
10+
#include <torch/torch.h>
11+
12+
namespace graphbolt {
13+
namespace sampling {
14+
/**
15+
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
16+
* 'src_ids' tensor and applies the uniqueness information to compact both
17+
* source and destination tensors.
18+
*
19+
* The function performs two main operations:
20+
* 1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which
21+
* the unique operator will guarantee the 'unique_dst_ids' are at the head of
22+
* the result tensor.
23+
* 2. Compact Operation: Utilizes the reverse mapping derived from the unique
24+
* operation to transform 'src_ids' and 'dst_ids' into compacted IDs.
25+
*
26+
* @param src_ids A tensor containing source IDs.
27+
* @param dst_ids A tensor containing destination IDs.
28+
* @param unique_dst_ids A tensor containing unique destination IDs, which is
29+
* exactly all the unique elements in 'dst_ids'.
30+
*
31+
* @return
32+
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
33+
* removing duplicates. The indices in this tensor precisely match the compacted
34+
* IDs of the corresponding elements.
35+
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
36+
* mapped to compacted IDs.
37+
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
38+
* mapped to compacted IDs.
39+
*
40+
* @example
41+
* torch::Tensor src_ids = src;
42+
* torch::Tensor dst_ids = dst;
43+
* torch::Tensor unique_dst_ids = torch::unique(dst);
44+
* auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
45+
* torch::Tensor unique_ids = std::get<0>(result);
46+
* torch::Tensor compacted_src_ids = std::get<1>(result);
47+
* torch::Tensor compacted_dst_ids = std::get<2>(result);
48+
*/
49+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
50+
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
51+
const torch::Tensor unique_dst_ids);
52+
53+
} // namespace sampling
54+
} // namespace graphbolt
55+
56+
#endif // GRAPHBOLT_UNIQUE_AND_COMPACT_H_

graphbolt/src/python_binding.cc

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <graphbolt/csc_sampling_graph.h>
88
#include <graphbolt/serialize.h>
9+
#include <graphbolt/unique_and_compact.h>
910

1011
namespace graphbolt {
1112
namespace sampling {
@@ -39,6 +40,7 @@ TORCH_LIBRARY(graphbolt, m) {
3940
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
4041
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
4142
m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
43+
m.def("unique_and_compact", &UniqueAndCompact);
4244
}
4345

4446
} // namespace sampling

graphbolt/src/unique_and_compact.cc

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* Copyright (c) 2023 by Contributors
3+
*
4+
* @file unique_and_compact.cc
5+
* @brief Unique and compact op.
6+
*/
7+
8+
#include <graphbolt/unique_and_compact.h>
9+
10+
#include "./concurrent_id_hash_map.h"
11+
12+
namespace graphbolt {
13+
namespace sampling {
14+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
15+
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
16+
const torch::Tensor unique_dst_ids) {
17+
torch::Tensor compacted_src_ids;
18+
torch::Tensor compacted_dst_ids;
19+
torch::Tensor unique_ids;
20+
auto num_dst = unique_dst_ids.size(0);
21+
torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});
22+
AT_DISPATCH_INTEGRAL_TYPES(ids.scalar_type(), "unique_and_compact", ([&] {
23+
ConcurrentIdHashMap<scalar_t> id_map;
24+
unique_ids = id_map.Init(ids, num_dst);
25+
compacted_src_ids = id_map.MapIds(src_ids);
26+
compacted_dst_ids = id_map.MapIds(dst_ids);
27+
}));
28+
return std::tuple(unique_ids, compacted_src_ids, compacted_dst_ids);
29+
}
30+
} // namespace sampling
31+
} // namespace graphbolt

python/dgl/graphbolt/utils/sample_utils.py

+57-33
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ def unique_and_compact_node_pairs(
1010
node_pairs: Union[
1111
Tuple[torch.Tensor, torch.Tensor],
1212
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
13-
]
13+
],
14+
unique_dst_nodes: Union[
15+
torch.Tensor,
16+
Dict[str, torch.Tensor],
17+
] = None,
1418
):
1519
"""
1620
Compact node pairs and return unique nodes (per type).
@@ -26,6 +30,11 @@ def unique_and_compact_node_pairs(
2630
- If `node_pairs` is a dictionary: The keys should be edge type and
2731
the values should be corresponding node pairs. And IDs inside are
2832
heterogeneous ids.
33+
unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]
34+
Unique nodes of all destination nodes in the node pairs.
35+
- If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.
36+
- If `unique_dst_nodes` is a dictionary: The keys are node type and the
37+
values are corresponding nodes. And IDs inside are heterogeneous ids.
2938
3039
Returns
3140
-------
@@ -52,44 +61,59 @@ def unique_and_compact_node_pairs(
5261
{('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])),
5362
('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))}
5463
"""
55-
is_homogeneous = not isinstance(node_pairs, Dict)
64+
is_homogeneous = not isinstance(node_pairs, dict)
5665
if is_homogeneous:
5766
node_pairs = {("_N", "_E", "_N"): node_pairs}
58-
nodes_dict = defaultdict(list)
59-
# Collect nodes for each node type.
60-
for etype, node_pair in node_pairs.items():
61-
u_type, _, v_type = etype
62-
u, v = node_pair
63-
nodes_dict[u_type].append(u)
64-
nodes_dict[v_type].append(v)
67+
if unique_dst_nodes is not None:
68+
assert isinstance(
69+
unique_dst_nodes, torch.Tensor
70+
), "Edge type not supported in homogeneous graph."
71+
unique_dst_nodes = {"_N": unique_dst_nodes}
6572

66-
unique_nodes_dict = {}
67-
inverse_indices_dict = {}
68-
for ntype, nodes in nodes_dict.items():
69-
collected_nodes = torch.cat(nodes)
70-
# Compact and find unique nodes.
71-
unique_nodes, inverse_indices = torch.unique(
72-
collected_nodes,
73-
return_inverse=True,
74-
)
75-
unique_nodes_dict[ntype] = unique_nodes
76-
inverse_indices_dict[ntype] = inverse_indices
73+
# Collect all source and destination nodes for each node type.
74+
src_nodes = defaultdict(list)
75+
dst_nodes = defaultdict(list)
76+
for etype, (src_node, dst_node) in node_pairs.items():
77+
src_nodes[etype[0]].append(src_node)
78+
dst_nodes[etype[2]].append(dst_node)
79+
src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()}
80+
dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()}
81+
# Compute unique destination nodes if not provided.
82+
if unique_dst_nodes is None:
83+
unique_dst_nodes = {
84+
ntype: torch.unique(nodes) for ntype, nodes in dst_nodes.items()
85+
}
86+
87+
ntypes = set(dst_nodes.keys()) | set(src_nodes.keys())
88+
unique_nodes = {}
89+
compacted_src = {}
90+
compacted_dst = {}
91+
dtype = list(src_nodes.values())[0].dtype
92+
default_tensor = torch.tensor([], dtype=dtype)
93+
for ntype in ntypes:
94+
src = src_nodes.get(ntype, default_tensor)
95+
unique_dst = unique_dst_nodes.get(ntype, default_tensor)
96+
dst = dst_nodes.get(ntype, default_tensor)
97+
(
98+
unique_nodes[ntype],
99+
compacted_src[ntype],
100+
compacted_dst[ntype],
101+
) = torch.ops.graphbolt.unique_and_compact(src, dst, unique_dst)
77102

78-
# Map back in same order as collect.
79103
compacted_node_pairs = {}
80-
unique_nodes = unique_nodes_dict
81-
for etype, node_pair in node_pairs.items():
82-
u_type, _, v_type = etype
83-
u, v = node_pair
84-
u_size, v_size = u.numel(), v.numel()
85-
u = inverse_indices_dict[u_type][:u_size]
86-
inverse_indices_dict[u_type] = inverse_indices_dict[u_type][u_size:]
87-
v = inverse_indices_dict[v_type][:v_size]
88-
inverse_indices_dict[v_type] = inverse_indices_dict[v_type][v_size:]
89-
compacted_node_pairs[etype] = (u, v)
104+
# Map back with the same order.
105+
for etype, pair in node_pairs.items():
106+
num_elem = pair[0].size(0)
107+
src_type, _, dst_type = etype
108+
src = compacted_src[src_type][:num_elem]
109+
dst = compacted_dst[dst_type][:num_elem]
110+
compacted_node_pairs[etype] = (src, dst)
111+
compacted_src[src_type] = compacted_src[src_type][num_elem:]
112+
compacted_dst[dst_type] = compacted_dst[dst_type][num_elem:]
90113

91-
# Return singleton for homogeneous graph.
114+
# Return singleton for a homogeneous graph.
92115
if is_homogeneous:
93116
compacted_node_pairs = list(compacted_node_pairs.values())[0]
94-
unique_nodes = list(unique_nodes_dict.values())[0]
117+
unique_nodes = list(unique_nodes.values())[0]
118+
95119
return unique_nodes, compacted_node_pairs
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,20 @@
11
import dgl.graphbolt as gb
2+
import pytest
23
import torch
34

45

56
def test_unique_and_compact_node_pairs_hetero():
67
N1 = torch.randint(0, 50, (30,))
78
N2 = torch.randint(0, 50, (20,))
89
N3 = torch.randint(0, 50, (10,))
9-
unique_N1, compacted_N1 = torch.unique(N1, return_inverse=True)
10-
unique_N2, compacted_N2 = torch.unique(N2, return_inverse=True)
11-
unique_N3, compacted_N3 = torch.unique(N3, return_inverse=True)
10+
unique_N1 = torch.unique(N1)
11+
unique_N2 = torch.unique(N2)
12+
unique_N3 = torch.unique(N3)
1213
expected_unique_nodes = {
1314
"n1": unique_N1,
1415
"n2": unique_N2,
1516
"n3": unique_N3,
1617
}
17-
expected_compacted_pairs = {
18-
("n1", "e1", "n2"): (
19-
compacted_N1[:20],
20-
compacted_N2,
21-
),
22-
("n1", "e2", "n3"): (
23-
compacted_N1[20:30],
24-
compacted_N3,
25-
),
26-
("n2", "e3", "n3"): (
27-
compacted_N2[10:],
28-
compacted_N3,
29-
),
30-
}
3118
node_pairs = {
3219
("n1", "e1", "n2"): (
3320
N1[:20],
@@ -46,27 +33,39 @@ def test_unique_and_compact_node_pairs_hetero():
4633
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
4734
node_pairs
4835
)
36+
for ntype, nodes in unique_nodes.items():
37+
expected_nodes = expected_unique_nodes[ntype]
38+
assert torch.equal(torch.sort(nodes)[0], expected_nodes)
4939
for etype, pair in compacted_node_pairs.items():
50-
expected_u, expected_v = expected_compacted_pairs[etype]
5140
u, v = pair
41+
u_type, _, v_type = etype
42+
u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
43+
expected_u, expected_v = node_pairs[etype]
5244
assert torch.equal(u, expected_u)
5345
assert torch.equal(v, expected_v)
54-
for ntype, nodes in unique_nodes.items():
55-
expected_nodes = expected_unique_nodes[ntype]
56-
assert torch.equal(nodes, expected_nodes)
5746

5847

5948
def test_unique_and_compact_node_pairs_homo():
60-
N = torch.randint(0, 50, (20,))
61-
expected_unique_N, compacted_N = torch.unique(N, return_inverse=True)
62-
expected_compacted_pairs = tuple(compacted_N.split(10))
49+
N = torch.randint(0, 50, (200,))
50+
expected_unique_N = torch.unique(N)
6351

64-
node_pairs = tuple(N.split(10))
52+
node_pairs = tuple(N.split(100))
6553
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
6654
node_pairs
6755
)
68-
expected_u, expected_v = expected_compacted_pairs
56+
assert torch.equal(torch.sort(unique_nodes)[0], expected_unique_N)
57+
6958
u, v = compacted_node_pairs
59+
u, v = unique_nodes[u], unique_nodes[v]
60+
expected_u, expected_v = node_pairs
61+
unique_v = torch.unique(expected_v)
7062
assert torch.equal(u, expected_u)
7163
assert torch.equal(v, expected_v)
72-
assert torch.equal(unique_nodes, expected_unique_N)
64+
assert torch.equal(unique_nodes[: unique_v.size(0)], unique_v)
65+
66+
67+
def test_incomplete_unique_dst_nodes_():
    """An incomplete `unique_dst_nodes` must make compaction raise IndexError."""
    # Destinations are drawn from [100, 150) while the supplied unique
    # destination set only spans [150, 200), so no destination ID is covered.
    src = torch.randint(0, 50, (50,))
    dst = torch.randint(100, 150, (50,))
    node_pairs = (src, dst)
    unique_dst_nodes = torch.arange(150, 200)
    with pytest.raises(IndexError):
        gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)

0 commit comments

Comments
 (0)