@@ -32,20 +32,23 @@ class UniformNegativeSampler(NegativeSampler):
32
32
Examples
33
33
--------
34
34
>>> from dgl import graphbolt as gb
35
- >>> indptr = torch.LongTensor([0, 2, 4, 5 ])
36
- >>> indices = torch.LongTensor([1, 2, 0, 2 , 0])
35
+ >>> indptr = torch.LongTensor([0, 1, 2, 3, 4 ])
36
+ >>> indices = torch.LongTensor([1, 2, 3 , 0])
37
37
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
38
- >>> node_pairs = ( torch.tensor([0, 1]), torch.tensor( [1, 2]) )
38
+ >>> node_pairs = torch.tensor([[ 0, 1], [1, 2], [2, 3], [3, 0]] )
39
39
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
40
40
>>> item_sampler = gb.ItemSampler(
41
- ... item_set, batch_size=1 ,)
41
+ ... item_set, batch_size=4 ,)
42
42
>>> neg_sampler = gb.UniformNegativeSampler(
43
43
... item_sampler, graph, 2)
44
44
>>> for minibatch in neg_sampler:
45
45
... print(minibatch.negative_srcs)
46
46
... print(minibatch.negative_dsts)
47
- (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
48
- (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
47
+ None
48
+ tensor([[2, 1],
49
+ [2, 1],
50
+ [3, 2],
51
+ [1, 3]])
49
52
"""
50
53
51
54
def __init__ (
0 commit comments