From 37b8e30aae101c372c20cfd6ae51d79f1786c3e4 Mon Sep 17 00:00:00 2001
From: e-yi
Date: Tue, 13 Jul 2021 13:55:52 +0800
Subject: [PATCH] add binary activation function

---
 model.py | 40 +++++++++++++++++++++++++++++++++-------
 1 file changed, 33 insertions(+), 7 deletions(-)

diff --git a/model.py b/model.py
index 6a8af35..b11770d 100644
--- a/model.py
+++ b/model.py
@@ -3,18 +3,19 @@
 import torch.nn as nn
 from torch.utils.data import Dataset, DataLoader
 
+def binary_reg(x: torch.Tensor):
+    # forward: f(x) = (x>=0)
+    # backward: f(x) = sigmoid
+    a = torch.sigmoid(x)
+    b = a.detach()
+    c = (x.detach() >= 0).float()
+    return a - b + c
 
 class HIN2vec(nn.Module):
 
     def __init__(self, node_size, path_size, embed_dim, sigmoid_reg=False, r=True):
         super().__init__()
 
-        # self.args = args
-
-        def binary_reg(x: torch.Tensor):
-            raise NotImplementedError()
-            # return (x >= 0).float()  # do not have gradients
-
         self.reg = torch.sigmoid if sigmoid_reg else binary_reg
 
         self.__initialize_model(node_size, path_size, embed_dim, r)
@@ -47,7 +48,6 @@ def forward(self, start_node: torch.LongTensor, end_node: torch.LongTensor, path
 
         return output
 
-
 def train(log_interval, model, device, train_loader: DataLoader, optimizer, loss_function, epoch):
     model.train()
     for idx, (data, target) in enumerate(train_loader):
@@ -102,3 +102,29 @@ def __getitem__(self, index):
 
     def __len__(self):
         return self.length
+if __name__ == '__main__':
+    ## test binary_reg
+
+    print('sigmoid')
+    a = torch.tensor([-1., 0., 1.], requires_grad=True)
+    b = torch.sigmoid(a)
+    c = b.sum()
+    print(a)
+    print(b)
+    print(c)
+    c.backward()
+    print(c.grad)
+    print(b.grad)
+    print(a.grad)
+
+    print('binary')
+    a = torch.tensor([-1., 0., 1.], requires_grad=True)
+    b = binary_reg(a)
+    c = b.sum()
+    print(a)
+    print(b)
+    print(c)
+    c.backward()
+    print(c.grad)
+    print(b.grad)
+    print(a.grad)
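
Note on the change above: in the new binary_reg, `a - b + c` equals `c` numerically, because the sigmoid term `a` and its detached copy `b` cancel in value. The forward pass therefore returns the hard threshold `(x >= 0).float()`, while autograd sees only the non-detached `torch.sigmoid(x)` term and backpropagates sigmoid's gradient as a surrogate. This is why the old inline `(x >= 0).float()` version, whose gradient is zero almost everywhere, was removed. In the `__main__` test, `print(c.grad)` and `print(b.grad)` print `None` since `b` and `c` are non-leaf tensors; only `a.grad` is populated after `c.backward()`.

A minimal standalone sketch of the same trick, independent of the patch (the input values are illustrative only):

    import torch

    def binary_reg(x: torch.Tensor):
        a = torch.sigmoid(x)
        # value: a - a.detach() is exactly 0, leaving the hard threshold
        # gradient: only the non-detached sigmoid term reaches autograd
        return a - a.detach() + (x.detach() >= 0).float()

    x = torch.tensor([-1., 0., 1.], requires_grad=True)
    y = binary_reg(x)
    y.sum().backward()
    print(y.detach())  # tensor([0., 1., 1.])  -- thresholded forward values
    print(x.grad)      # approx. [0.1966, 0.2500, 0.1966] = sigmoid'(x)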