@@ -10,7 +10,11 @@ def unique_and_compact_node_pairs(
10
10
node_pairs : Union [
11
11
Tuple [torch .Tensor , torch .Tensor ],
12
12
Dict [Tuple [str , str , str ], Tuple [torch .Tensor , torch .Tensor ]],
13
- ]
13
+ ],
14
+ unique_dst_nodes : Union [
15
+ torch .Tensor ,
16
+ Dict [str , torch .Tensor ],
17
+ ] = None ,
14
18
):
15
19
"""
16
20
Compact node pairs and return unique nodes (per type).
@@ -26,6 +30,11 @@ def unique_and_compact_node_pairs(
26
30
- If `node_pairs` is a dictionary: The keys should be edge type and
27
31
the values should be corresponding node pairs. And IDs inside are
28
32
heterogeneous ids.
33
+ unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]
34
+ Unique nodes of all destination nodes in the node pairs.
35
+ - If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.
36
+ - If `unique_dst_nodes` is a dictionary: The keys are node type and the
37
+ values are corresponding nodes. And IDs inside are heterogeneous ids.
29
38
30
39
Returns
31
40
-------
@@ -52,44 +61,59 @@ def unique_and_compact_node_pairs(
52
61
{('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])),
53
62
('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))}
54
63
"""
55
- is_homogeneous = not isinstance (node_pairs , Dict )
64
+ is_homogeneous = not isinstance (node_pairs , dict )
56
65
if is_homogeneous :
57
66
node_pairs = {("_N" , "_E" , "_N" ): node_pairs }
58
- nodes_dict = defaultdict (list )
59
- # Collect nodes for each node type.
60
- for etype , node_pair in node_pairs .items ():
61
- u_type , _ , v_type = etype
62
- u , v = node_pair
63
- nodes_dict [u_type ].append (u )
64
- nodes_dict [v_type ].append (v )
67
+ if unique_dst_nodes is not None :
68
+ assert isinstance (
69
+ unique_dst_nodes , torch .Tensor
70
+ ), "Edge type not supported in homogeneous graph."
71
+ unique_dst_nodes = {"_N" : unique_dst_nodes }
65
72
66
- unique_nodes_dict = {}
67
- inverse_indices_dict = {}
68
- for ntype , nodes in nodes_dict .items ():
69
- collected_nodes = torch .cat (nodes )
70
- # Compact and find unique nodes.
71
- unique_nodes , inverse_indices = torch .unique (
72
- collected_nodes ,
73
- return_inverse = True ,
74
- )
75
- unique_nodes_dict [ntype ] = unique_nodes
76
- inverse_indices_dict [ntype ] = inverse_indices
73
+ # Collect all source and destination nodes for each node type.
74
+ src_nodes = defaultdict (list )
75
+ dst_nodes = defaultdict (list )
76
+ for etype , (src_node , dst_node ) in node_pairs .items ():
77
+ src_nodes [etype [0 ]].append (src_node )
78
+ dst_nodes [etype [2 ]].append (dst_node )
79
+ src_nodes = {ntype : torch .cat (nodes ) for ntype , nodes in src_nodes .items ()}
80
+ dst_nodes = {ntype : torch .cat (nodes ) for ntype , nodes in dst_nodes .items ()}
81
+ # Compute unique destination nodes if not provided.
82
+ if unique_dst_nodes is None :
83
+ unique_dst_nodes = {
84
+ ntype : torch .unique (nodes ) for ntype , nodes in dst_nodes .items ()
85
+ }
86
+
87
+ ntypes = set (dst_nodes .keys ()) | set (src_nodes .keys ())
88
+ unique_nodes = {}
89
+ compacted_src = {}
90
+ compacted_dst = {}
91
+ dtype = list (src_nodes .values ())[0 ].dtype
92
+ default_tensor = torch .tensor ([], dtype = dtype )
93
+ for ntype in ntypes :
94
+ src = src_nodes .get (ntype , default_tensor )
95
+ unique_dst = unique_dst_nodes .get (ntype , default_tensor )
96
+ dst = dst_nodes .get (ntype , default_tensor )
97
+ (
98
+ unique_nodes [ntype ],
99
+ compacted_src [ntype ],
100
+ compacted_dst [ntype ],
101
+ ) = torch .ops .graphbolt .unique_and_compact (src , dst , unique_dst )
77
102
78
- # Map back in same order as collect.
79
103
compacted_node_pairs = {}
80
- unique_nodes = unique_nodes_dict
81
- for etype , node_pair in node_pairs .items ():
82
- u_type , _ , v_type = etype
83
- u , v = node_pair
84
- u_size , v_size = u .numel (), v .numel ()
85
- u = inverse_indices_dict [u_type ][:u_size ]
86
- inverse_indices_dict [u_type ] = inverse_indices_dict [u_type ][u_size :]
87
- v = inverse_indices_dict [v_type ][:v_size ]
88
- inverse_indices_dict [v_type ] = inverse_indices_dict [v_type ][v_size :]
89
- compacted_node_pairs [etype ] = (u , v )
104
+ # Map back with the same order.
105
+ for etype , pair in node_pairs .items ():
106
+ num_elem = pair [0 ].size (0 )
107
+ src_type , _ , dst_type = etype
108
+ src = compacted_src [src_type ][:num_elem ]
109
+ dst = compacted_dst [dst_type ][:num_elem ]
110
+ compacted_node_pairs [etype ] = (src , dst )
111
+ compacted_src [src_type ] = compacted_src [src_type ][num_elem :]
112
+ compacted_dst [dst_type ] = compacted_dst [dst_type ][num_elem :]
90
113
91
- # Return singleton for homogeneous graph.
114
+ # Return singleton for a homogeneous graph.
92
115
if is_homogeneous :
93
116
compacted_node_pairs = list (compacted_node_pairs .values ())[0 ]
94
- unique_nodes = list (unique_nodes_dict .values ())[0 ]
117
+ unique_nodes = list (unique_nodes .values ())[0 ]
118
+
95
119
return unique_nodes , compacted_node_pairs
0 commit comments