import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same' output shape (standard YOLOv5 helper; Conv below calls it
    # but the definition was missing from this file).
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


class Conv(nn.Module):
    # Standard Conv2d -> BatchNorm2d -> activation block
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        # Used once the BatchNorm has been fused into the conv weights
        return self.act(self.conv(x))
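
# Usage sketch (illustrative, not part of the module API): with k=3, s=2 the
# autopad keeps 'same' alignment while the stride halves the spatial size.
#   m = Conv(64, 128, k=3, s=2)
#   m(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 128, 16, 16])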


class depthwise_separable_conv(nn.Module):
    # https://wingnim.tistory.com/104
    # Depthwise conv followed by a 1x1 pointwise conv. Note that
    # kernels_per_layer doubles as the depthwise channel multiplier *and* the
    # depthwise kernel size (GSConv below passes 3 for both roles).
    def __init__(self, nin, nout, kernels_per_layer=2, stride=1, padding=1, bias=True):
        super().__init__()
        self.depthwise = nn.Conv2d(nin, nin * kernels_per_layer, kernel_size=kernels_per_layer,
                                   stride=stride, padding=padding, groups=nin, bias=bias)
        self.pointwise = nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
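
# Usage sketch (illustrative): called as GSConv does below, the block keeps the
# spatial size (kernel 3, stride 1, padding 1) and maps 32 -> 64 channels.
#   dw = depthwise_separable_conv(32, 64, kernels_per_layer=3, stride=1, padding=1)
#   dw(torch.randn(1, 32, 16, 16)).shape  # torch.Size([1, 64, 16, 16])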


def channel_shuffle(x, groups):
    # https://github.com/jaxony/ShuffleNet/blob/master/model.py
    batchsize, num_channels, height, width = x.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x
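
# Worked example (illustrative): with 4 channels and 2 groups the two halves
# are interleaved, i.e. channel order [0, 1, 2, 3] becomes [0, 2, 1, 3].
#   x = torch.arange(4.).view(1, 4, 1, 1)
#   channel_shuffle(x, 2).flatten()  # tensor([0., 2., 1., 3.])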


class GSConv(nn.Module):
    # Slim-neck by GSConv
    def __init__(self, c1, c2, k=1, s=1, p=None, bias=True):  # ch_in, ch_out, kernel, stride, padding
        super().__init__()
        self.conv = Conv(c1, c2 // 2, k, s, p)
        self.dwconv = depthwise_separable_conv(c2 // 2, c2 // 2, 3, 1, 1, bias=bias)
        # self.shuf = nn.ChannelShuffle(c2)

    def forward(self, x):
        x = self.conv(x)
        xd = self.dwconv(x)
        x = torch.cat((x, xd), 1)
        # Shuffle with 2 groups to interleave the dense and depthwise branches;
        # shuffling with groups equal to the channel count would be an identity
        # permutation (channels_per_group == 1), i.e. no shuffle at all.
        x = channel_shuffle(x, 2)
        return x
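
# Usage sketch (illustrative): with the defaults (k=1, s=1) GSConv keeps the
# spatial size; half the outputs come from Conv, half from the depthwise branch.
#   gs = GSConv(64, 128)
#   gs(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 128, 32, 32])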


class GSConvBottleNeck(nn.Module):
    # Slim-neck by GSConv
    def __init__(self, c1, c2, k=1, s=1, p=None, bias=True):  # ch_in, ch_out, kernel, stride, padding
        super().__init__()
        self.GSconv0 = GSConv(c1, c2 // 2, k, s, p)
        self.GSconv1 = GSConv(c2 // 2, c2, 3, 1, 1, bias=bias)

    def forward(self, x):
        # Residual add requires c1 == c2 and a shape-preserving first GSConv
        # (i.e. the default k=1, s=1).
        x_res = self.GSconv0(x)
        x_res = self.GSconv1(x_res)
        return x + x_res
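
# Usage sketch (illustrative): the residual form, so in and out channels match.
#   bn = GSConvBottleNeck(64, 64)
#   bn(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 64, 32, 32])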


class VoVGSConv(nn.Module):
    # Slim-neck by GSConv
    def __init__(self, c1, c2, k=1, s=1, p=None, bias=True):  # ch_in, ch_out, kernel, stride, padding
        super().__init__()
        self.conv0 = nn.Conv2d(c1, c1 // 2, 3, 1, 1, bias=bias)
        self.GSconv0 = GSConv(c1 // 2, c1 // 2, 3, 1, 1, bias=bias)
        self.GSconv1 = GSConv(c1 // 2, c1 // 2, 3, 1, 1, bias=bias)
        # autopad() resolves the p=None default here; nn.Conv2d itself does
        # not accept padding=None.
        self.conv2 = nn.Conv2d((c1 // 2) * 2, c2, k, s, autopad(k, p), bias=bias)

    def forward(self, x):
        x = self.conv0(x)
        x_1 = self.GSconv0(x)
        x_1 = self.GSconv1(x_1)
        x = torch.cat((x, x_1), 1)
        x = self.conv2(x)
        return x
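

# Minimal smoke test (illustrative sketch; shapes assume the defaults above).
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    print(GSConv(64, 128)(x).shape)           # torch.Size([2, 128, 32, 32])
    print(GSConvBottleNeck(64, 64)(x).shape)  # torch.Size([2, 64, 32, 32])
    print(VoVGSConv(64, 128)(x).shape)        # torch.Size([2, 128, 32, 32])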