From 56d74329cc3893c855824e539d6197c2b45ff2c5 Mon Sep 17 00:00:00 2001
From: e-yi
Date: Mon, 21 Oct 2019 18:03:44 +0800
Subject: [PATCH] fix a mistake about negative sampling

---
 model.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/model.py b/model.py
index 8616204..1ad6a89 100644
--- a/model.py
+++ b/model.py
@@ -12,7 +12,8 @@ def __init__(self, node_size, path_size, embed_dim, sigmoid_reg=False, r=True):
         # self.args = args
 
         def binary_reg(x: torch.Tensor):
-            return (x >= 0).float()
+            raise NotImplementedError()
+            # return (x >= 0).float()  # does not have gradients
 
         self.reg = torch.sigmoid if sigmoid_reg else binary_reg
 
@@ -78,18 +79,23 @@ def __init__(self, sample, path_size, neg=5):
         :param sample: return value of HIN.sample(), (start_node, end_node, path_id)
         """
+        print('init training dataset...')
+
         l = len(sample)
         x = np.tile(sample, (neg + 1, 1))
         y = np.zeros(l * (1 + neg))
         y[:l] = 1
 
-        x[l:, 2] = np.random.randint(0, path_size - 1, (l * neg,))
+        # x[l:, 2] = np.random.randint(0, path_size - 1, (l * neg,))
+        x[l:, 1] = np.random.randint(0, path_size - 1, (l * neg,))
 
         self.x = torch.LongTensor(x)
         self.y = torch.FloatTensor(y)
         self.length = len(x)
 
+        print('finished')
+
     def __getitem__(self, index):
         return self.x[index], self.y[index]
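
For context, below is a minimal, self-contained sketch of the training-set construction as it looks after this patch: each (start_node, end_node, path_id) triple is repeated neg + 1 times, the first copy is labeled positive, and the remaining copies are turned into negatives by corrupting column 1 (end_node) rather than column 2 (path_id). The class name NegSamplingDataset and the toy triples are illustrative assumptions, not names from the repository; the sampling range mirrors the patch's use of path_size.

# Sketch only; class name and example data are assumptions, not repository code.
import numpy as np
import torch
from torch.utils.data import Dataset


class NegSamplingDataset(Dataset):
    def __init__(self, sample, path_size, neg=5):
        # sample: array of (start_node, end_node, path_id) triples
        sample = np.asarray(sample)
        l = len(sample)

        # Repeat every triple (neg + 1) times; the first l rows stay positive
        # (label 1), the remaining l * neg rows become negatives (label 0).
        x = np.tile(sample, (neg + 1, 1))
        y = np.zeros(l * (1 + neg))
        y[:l] = 1

        # Corrupt column 1 (end_node) of the negative rows, as the patch does,
        # instead of column 2 (path_id) as the code did before the fix.
        x[l:, 1] = np.random.randint(0, path_size - 1, (l * neg,))

        self.x = torch.LongTensor(x)
        self.y = torch.FloatTensor(y)
        self.length = len(x)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self.x[index], self.y[index]


# Usage sketch with three toy (start_node, end_node, path_id) triples.
triples = [(0, 1, 0), (1, 2, 1), (2, 0, 0)]
ds = NegSamplingDataset(triples, path_size=2, neg=5)
print(len(ds), ds[0])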