Skip to content

Commit

Permalink
add binary activation function
Browse files Browse the repository at this point in the history
  • Loading branch information
e-yi committed Jul 13, 2021
1 parent 9f12b34 commit 37b8e30
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def binary_reg(x: torch.Tensor):
    """Hard 0/1 threshold with a straight-through sigmoid gradient.

    Forward pass evaluates to 1.0 where ``x >= 0`` and 0.0 elsewhere;
    the backward pass uses the sigmoid derivative as a surrogate so the
    step function stays trainable.
    """
    soft = torch.sigmoid(x)
    hard = (x.detach() >= 0).float()
    # Value equals `hard`, since soft - soft.detach() is numerically zero;
    # gradients flow only through `soft`.
    return soft - soft.detach() + hard

class HIN2vec(nn.Module):

def __init__(self, node_size, path_size, embed_dim, sigmoid_reg=False, r=True):
super().__init__()

# self.args = args

def binary_reg(x: torch.Tensor):
raise NotImplementedError()
# return (x >= 0).float() # do not have gradients

self.reg = torch.sigmoid if sigmoid_reg else binary_reg

self.__initialize_model(node_size, path_size, embed_dim, r)
Expand Down Expand Up @@ -47,7 +48,6 @@ def forward(self, start_node: torch.LongTensor, end_node: torch.LongTensor, path
return output



def train(log_interval, model, device, train_loader: DataLoader, optimizer, loss_function, epoch):
model.train()
for idx, (data, target) in enumerate(train_loader):
Expand Down Expand Up @@ -102,3 +102,29 @@ def __getitem__(self, index):
def __len__(self):
    """Return the dataset size cached in ``self.length``."""
    return self.length

def _demo_grads(label, fn):
    """Apply `fn` to a small leaf tensor and print forward values and grads."""
    print(label)
    a = torch.tensor([-1., 0., 1.], requires_grad=True)
    b = fn(a)
    c = b.sum()
    # FIX: b and c are non-leaf tensors, so their .grad is never populated
    # by default — the original printed `None` (with a UserWarning).
    # retain_grad() makes the intermediate gradients inspectable.
    b.retain_grad()
    c.retain_grad()
    print(a)
    print(b)
    print(c)
    c.backward()
    print(c.grad)
    print(b.grad)
    print(a.grad)


if __name__ == '__main__':
    ## test binary_reg: compare against a plain sigmoid baseline
    _demo_grads('sigmoid', torch.sigmoid)
    _demo_grads('binary', binary_reg)

0 comments on commit 37b8e30

Please sign in to comment.