Skip to content

Commit 92c8f08

Browse files
Rhett-Yingpeizhou001Ubuntu
authored
[Graphbolt]Fix negative sampler (#6933) (#6938)
Co-authored-by: peizhou001 <[email protected]> Co-authored-by: Ubuntu <[email protected]>
1 parent c047950 commit 92c8f08

File tree

7 files changed

+98
-151
lines changed

7 files changed

+98
-151
lines changed

graphbolt/include/graphbolt/fused_csc_sampling_graph.h

-26
Original file line numberDiff line numberDiff line change
@@ -356,32 +356,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
356356
torch::optional<std::string> node_timestamp_attr_name,
357357
torch::optional<std::string> edge_timestamp_attr_name) const;
358358

359-
/**
360-
* @brief Sample negative edges by randomly choosing negative
361-
* source-destination pairs according to a uniform distribution. For each edge
362-
* ``(u, v)``, it is supposed to generate `negative_ratio` pairs of negative
363-
* edges ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
364-
* the graph.
365-
*
366-
* @param node_pairs A tuple of two 1D tensors that represent the source and
367-
* destination of positive edges, with 'positive' indicating that these edges
368-
* are present in the graph. It's important to note that within the context of
369-
* a heterogeneous graph, the ids in these tensors signify heterogeneous ids.
370-
* @param negative_ratio The ratio of the number of negative samples to
371-
* positive samples.
372-
* @param max_node_id The maximum ID of the node to be selected. It
373-
* should correspond to the number of nodes of a specific type.
374-
*
375-
* @return A tuple consisting of two 1D tensors represents the source and
376-
* destination of negative edges. In the context of a heterogeneous
377-
* graph, both the input nodes and the selected nodes are represented
378-
* by heterogeneous IDs. Note that negative refers to false negatives,
379-
* which means the edge could be present or not present in the graph.
380-
*/
381-
std::tuple<torch::Tensor, torch::Tensor> SampleNegativeEdgesUniform(
382-
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
383-
int64_t negative_ratio, int64_t max_node_id) const;
384-
385359
/**
386360
* @brief Copy the graph to shared memory.
387361
* @param shared_memory_name The name of the shared memory.

graphbolt/src/fused_csc_sampling_graph.cc

-12
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,6 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
692692
edge_timestamp));
693693
}
694694

695-
std::tuple<torch::Tensor, torch::Tensor>
696-
FusedCSCSamplingGraph::SampleNegativeEdgesUniform(
697-
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
698-
int64_t negative_ratio, int64_t max_node_id) const {
699-
torch::Tensor pos_src;
700-
std::tie(pos_src, std::ignore) = node_pairs;
701-
auto neg_len = pos_src.size(0) * negative_ratio;
702-
auto neg_src = pos_src.repeat(negative_ratio);
703-
auto neg_dst = torch::randint(0, max_node_id, {neg_len}, pos_src.options());
704-
return std::make_tuple(neg_src, neg_dst);
705-
}
706-
707695
static c10::intrusive_ptr<FusedCSCSamplingGraph>
708696
BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
709697
helper.InitializeRead();

graphbolt/src/python_binding.cc

-3
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ TORCH_LIBRARY(graphbolt, m) {
5252
.def(
5353
"temporal_sample_neighbors",
5454
&FusedCSCSamplingGraph::TemporalSampleNeighbors)
55-
.def(
56-
"sample_negative_edges_uniform",
57-
&FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
5855
.def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
5956
.def_pickle(
6057
// __getstate__

python/dgl/graphbolt/impl/fused_csc_sampling_graph.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,8 @@ def sample_negative_edges_uniform(
850850
pairs according to a uniform distribution. For each edge ``(u, v)``,
851851
it is supposed to generate `negative_ratio` pairs of negative edges
852852
``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
853-
the graph.
853+
the graph. As ``u`` is exactly same as the corresponding positive edges,
854+
it returns None for negative sources.
854855
855856
Parameters
856857
----------
@@ -877,23 +878,22 @@ def sample_negative_edges_uniform(
877878
`edge_type`. Note that negative refers to false negatives, which
878879
means the edge could be present or not present in the graph.
879880
"""
880-
if edge_type is not None:
881-
assert (
882-
self.node_type_offset is not None
883-
), "The 'node_type_offset' array is necessary for performing \
884-
negative sampling by edge type."
885-
_, _, dst_node_type = etype_str_to_tuple(edge_type)
886-
dst_node_type_id = self.node_type_to_id[dst_node_type]
887-
max_node_id = (
888-
self.node_type_offset[dst_node_type_id + 1]
889-
- self.node_type_offset[dst_node_type_id]
890-
)
881+
if edge_type:
882+
_, _, dst_ntype = etype_str_to_tuple(edge_type)
883+
max_node_id = self.num_nodes[dst_ntype]
891884
else:
892885
max_node_id = self.total_num_nodes
893-
return self._c_csc_graph.sample_negative_edges_uniform(
894-
node_pairs,
895-
negative_ratio,
896-
max_node_id,
886+
pos_src, _ = node_pairs
887+
num_negative = pos_src.size(0) * negative_ratio
888+
return (
889+
None,
890+
torch.randint(
891+
0,
892+
max_node_id,
893+
(num_negative,),
894+
dtype=pos_src.dtype,
895+
device=pos_src.device,
896+
),
897897
)
898898

899899
def copy_to_shared_memory(self, shared_memory_name: str):

python/dgl/graphbolt/impl/uniform_negative_sampler.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,23 @@ class UniformNegativeSampler(NegativeSampler):
3232
Examples
3333
--------
3434
>>> from dgl import graphbolt as gb
35-
>>> indptr = torch.LongTensor([0, 2, 4, 5])
36-
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
35+
>>> indptr = torch.LongTensor([0, 1, 2, 3, 4])
36+
>>> indices = torch.LongTensor([1, 2, 3, 0])
3737
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
38-
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
38+
>>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
3939
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
4040
>>> item_sampler = gb.ItemSampler(
41-
... item_set, batch_size=1,)
41+
... item_set, batch_size=4,)
4242
>>> neg_sampler = gb.UniformNegativeSampler(
4343
... item_sampler, graph, 2)
4444
>>> for minibatch in neg_sampler:
4545
... print(minibatch.negative_srcs)
4646
... print(minibatch.negative_dsts)
47-
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
48-
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
47+
None
48+
tensor([[2, 1],
49+
[2, 1],
50+
[3, 2],
51+
[1, 3]])
4952
"""
5053

5154
def __init__(

tests/python/pytorch/graphbolt/impl/test_negative_sampler.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def test_UniformNegativeSampler_invoke():
4646
def _verify(negative_sampler):
4747
for data in negative_sampler:
4848
# Assertation
49-
assert data.negative_srcs.size(0) == batch_size
50-
assert data.negative_srcs.size(1) == negative_ratio
49+
assert data.negative_srcs is None
5150
assert data.negative_dsts.size(0) == batch_size
5251
assert data.negative_dsts.size(1) == negative_ratio
5352

@@ -90,12 +89,9 @@ def test_Uniform_NegativeSampler(negative_ratio):
9089
# Assertation
9190
assert len(pos_src) == batch_size
9291
assert len(pos_dst) == batch_size
93-
assert len(neg_src) == batch_size
9492
assert len(neg_dst) == batch_size
95-
assert neg_src.numel() == batch_size * negative_ratio
93+
assert neg_src is None
9694
assert neg_dst.numel() == batch_size * negative_ratio
97-
expected_src = pos_src.repeat(negative_ratio).view(-1, negative_ratio)
98-
assert torch.equal(expected_src, neg_src)
9995

10096

10197
def get_hetero_graph():

0 commit comments

Comments
 (0)