Skip to content

Commit 6da0d63

Browse files
committed
Gsconv
1 parent 3c2eccc commit 6da0d63

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

Gsconv.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
class Conv(nn.Module):
5+
6+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
7+
super().__init__()
8+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
9+
self.bn = nn.BatchNorm2d(c2)
10+
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
11+
12+
def forward(self, x):
13+
return self.act(self.bn(self.conv(x)))
14+
15+
def forward_fuse(self, x):
16+
return self.act(self.conv(x))
17+
18+
class depthwise_separable_conv(nn.Module):
19+
# https://wingnim.tistory.com/104
20+
def __init__(self, nin, nout, kernels_per_layer=2, stride=1, padding=1, bias=True ):
21+
super(depthwise_separable_conv, self).__init__()
22+
self.depthwise = nn.Conv2d(nin, nin * kernels_per_layer, kernel_size=kernels_per_layer, stride=stride, padding=padding, groups=nin, bias=bias)
23+
self.pointwise = nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1,bias=bias)
24+
def forward(self, x):
25+
out = self.depthwise(x)
26+
out = self.pointwise(out)
27+
return out
28+
29+
def channel_shuffle(x, groups):
30+
# https://github.com/jaxony/ShuffleNet/blob/master/model.py
31+
batchsize, num_channels, height, width = x.data.size()
32+
33+
channels_per_group = num_channels // groups
34+
35+
# reshape
36+
x = x.view(batchsize, groups,
37+
channels_per_group, height, width)
38+
39+
# transpose
40+
# - contiguous() required if transpose() is used before view().
41+
# See https://github.com/pytorch/pytorch/issues/764
42+
x = torch.transpose(x, 1, 2).contiguous()
43+
44+
# flatten
45+
x = x.view(batchsize, -1, height, width)
46+
47+
return x
48+
49+
class GSConv(nn.Module):
50+
# Slim-neck by GSConv
51+
def __init__(self, c1, c2, k=1, s=1, p=None, bias=True): # ch_in, ch_out, kernel, stride, padding, groups
52+
super(GSConv, self).__init__()
53+
self.conv = Conv(c1, c2//2, k, s, p)
54+
self.dwconv = depthwise_separable_conv(c2//2, c2//2, 3, 1, 1, bias=bias)
55+
# self.shuf = nn.ChannelShuffle(c2)
56+
57+
def forward(self, x):
58+
x = self.conv(x)
59+
xd = self.dwconv(x)
60+
x = torch.cat((x,xd),1)
61+
_,C,_,_ = x.shape
62+
x = channel_shuffle(x, C)
63+
return x
64+
65+
class GSConvBottleNeck(nn.Module):
66+
# Slim-neck by GSConv
67+
def __init__(self, c1, c2, k=1, s=1, p=None, bias=True): # ch_in, ch_out, kernel, stride, padding, groups
68+
super(GSConvBottleNeck, self).__init__()
69+
self.GSconv0 = GSConv(c1, c2//2, k, s, p)
70+
self.GSconv1 = GSConv(c2//2, c2, 3, 1, 1, bias=bias)
71+
72+
def forward(self, x):
73+
x_res = self.GSconv0(x)
74+
x_res = self.GSconv1(x_res)
75+
76+
return x+x_res
77+
78+
class VoVGSConv(nn.Module):
79+
# Slim-neck by GSConv
80+
def __init__(self, c1, c2, k=1, s=1, p=None, bias=True): # ch_in, ch_out, kernel, stride, padding, groups
81+
super(VoVGSConv, self).__init__()
82+
self.conv0 = nn.Conv2d(c1, c1//2, 3, 1, 1, bias=bias)
83+
self.GSconv0 = GSConv(c1//2, c1//2, 3, 1, 1, bias=bias)
84+
self.GSconv1 = GSConv(c1//2, c1//2, 3, 1, 1, bias=bias)
85+
self.conv2 = nn.Conv2d((c1//2)*2, c2, k, s, p, bias=bias)
86+
87+
def forward(self, x):
88+
x = self.conv0(x)
89+
x_1 = self.GSconv0(x)
90+
x_1 = self.GSconv1(x_1)
91+
x = torch.cat((x,x_1), 1)
92+
x = self.conv2(x)
93+
return x

0 commit comments

Comments
 (0)