########################################
__author__ = "Abdelrahman Eldesokey"
__license__ = "GNU GPLv3"
__version__ = "0.1"
__maintainer__ = "Abdelrahman Eldesokey"
__email__ = "[email protected]"
########################################
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
import numpy as np
from scipy.stats import poisson
from scipy import signal
import math

# The proposed Normalized Convolution Layer
class NConv2d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, pos_fn='softplus',
                 init_method='k', stride=1, padding=0, dilation=1, groups=1, bias=True):
        # _ConvNd expects tuples for the spatial args; padding_mode is required on PyTorch >= 1.1
        super(NConv2d, self).__init__(in_channels, out_channels, _pair(kernel_size),
                                      _pair(stride), _pair(padding), _pair(dilation),
                                      False, _pair(0), groups, bias, padding_mode='zeros')
        self.eps = 1e-20
        self.pos_fn = pos_fn
        self.init_method = init_method

        # Initialize weights and bias
        self.init_parameters()

        if self.pos_fn is not None:
            EnforcePos.apply(self, 'weight', pos_fn)

    def forward(self, data, conf):
        # Normalized Convolution
        denom = F.conv2d(conf, self.weight, None, self.stride,
                         self.padding, self.dilation, self.groups)
        nomin = F.conv2d(data * conf, self.weight, None, self.stride,
                         self.padding, self.dilation, self.groups)
        nconv = nomin / (denom + self.eps)

        # Add bias
        b = self.bias
        sz = b.size(0)
        b = b.view(1, sz, 1, 1)
        b = b.expand_as(nconv)
        nconv += b

        # Propagate confidence
        cout = denom
        sz = cout.size()
        cout = cout.view(sz[0], sz[1], -1)

        k = self.weight
        k_sz = k.size()
        k = k.view(k_sz[0], -1)
        s = torch.sum(k, dim=-1, keepdim=True)

        cout = cout / s
        cout = cout.view(sz)

        return nconv, cout

    def init_parameters(self):
        # Init weights
        if self.init_method == 'x':  # Xavier
            torch.nn.init.xavier_uniform_(self.weight)
        elif self.init_method == 'k':  # Kaiming
            torch.nn.init.kaiming_uniform_(self.weight)
        elif self.init_method == 'n':  # Normal dist
            n = self.kernel_size[0] * self.kernel_size[1] * self.out_channels
            self.weight.data.normal_(2, math.sqrt(2. / n))
        elif self.init_method == 'p':  # Poisson
            mu = self.kernel_size[0] / 2
            dist = poisson(mu)
            x = np.arange(0, self.kernel_size[0])
            y = np.expand_dims(dist.pmf(x), 1)
            w = signal.convolve2d(y, y.transpose(), 'full')
            w = torch.Tensor(w).type_as(self.weight)
            w = torch.unsqueeze(w, 0)
            w = torch.unsqueeze(w, 1)
            w = w.repeat(self.out_channels, 1, 1, 1)
            w = w.repeat(1, self.in_channels, 1, 1)
            self.weight.data = w + torch.rand(w.shape)
        else:
            raise ValueError('Undefined initialization method!')

        # Init bias
        self.bias = torch.nn.Parameter(torch.zeros(self.out_channels) + 0.01)
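

# For reference, with kernel W, bias b, input x and confidence map c, the
# forward pass above computes, per output channel,
#
#     nconv = ( W * (x . c) ) / ( W * c + eps ) + b
#     cout  = ( W * c ) / sum(W)
#
# where '*' is 2-D convolution and '.' is elementwise multiplication: the
# normalized-convolution estimate together with its propagated confidence.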

# Non-negativity enforcement class
class EnforcePos(object):
    def __init__(self, pos_fn, name):
        self.name = name
        self.pos_fn = pos_fn

    @staticmethod
    def apply(module, name, pos_fn):
        fn = EnforcePos(pos_fn, name)
        # Re-project the weights through the positive function before every
        # training-mode forward pass
        module.register_forward_pre_hook(fn)
        return fn

    def __call__(self, module, inputs):
        if module.training:
            weight = getattr(module, self.name)
            weight.data = self._pos(weight).data

    def _pos(self, p):
        pos_fn = self.pos_fn.lower()
        if pos_fn == 'softmax':
            p_sz = p.size()
            p = p.view(p_sz[0], p_sz[1], -1)
            p = F.softmax(p, -1)
            return p.view(p_sz)
        elif pos_fn == 'exp':
            return torch.exp(p)
        elif pos_fn == 'softplus':
            return F.softplus(p, beta=10)
        elif pos_fn == 'sigmoid':
            return torch.sigmoid(p)  # F.sigmoid is deprecated
        else:
            raise ValueError('Undefined positive function!')
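

# A minimal smoke-test sketch of the layer: `conf` is taken to be 1 where the
# sparse input is observed and 0 elsewhere; the 8x8 size and ~30% density
# below are illustrative values, not the original training setup.
if __name__ == '__main__':
    torch.manual_seed(0)

    layer = NConv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

    dense = torch.rand(1, 1, 8, 8)                   # underlying signal
    conf = (torch.rand(1, 1, 8, 8) > 0.7).float()    # binary confidence mask
    sparse = dense * conf                            # simulated sparse input

    layer.train()                                    # EnforcePos fires in training mode
    out, cout = layer(sparse, conf)

    print(out.shape, cout.shape)                     # both torch.Size([1, 2, 8, 8])
    print('min weight:', layer.weight.min().item())  # > 0 after re-projection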