import os
import sys
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

import torch
import torch.nn as nn
import torch.optim


def download_url(url, model_dir="~/.torch/proxyless_nas", overwrite=False):
    model_dir = os.path.expanduser(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file) or overwrite:
        os.makedirs(model_dir, exist_ok=True)
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return cached_file


def load_url(url, model_dir='~/.torch/proxyless_nas', map_location=None):
    cached_file = download_url(url, model_dir)
    # Fall back to CPU only when the caller did not request a mapping and no
    # GPU is available; the original one-liner also discarded a caller-supplied
    # map_location.
    if map_location is None and not torch.cuda.is_available():
        map_location = 'cpu'
    return torch.load(cached_file, map_location=map_location)
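
# Illustrative usage; the URL below is a placeholder, not a real release asset:
#
#   state_dict = load_url('https://example.com/checkpoints/proxyless_gpu.pth')
#   net.load_state_dict(state_dict)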


def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    logsoftmax = nn.LogSoftmax(dim=1)  # explicit dim; the implicit form is deprecated
    n_classes = pred.size(1)
    # convert hard labels to one-hot
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros_like(pred)
    soft_target.scatter_(1, target, 1)
    # label smoothing: keep (1 - eps) on the true class, spread eps uniformly
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
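
# Illustrative usage; shapes are assumptions for the example:
#
#   logits = torch.randn(8, 1000)           # (batch, n_classes)
#   labels = torch.randint(0, 1000, (8,))   # (batch,)
#   loss = cross_entropy_with_label_smoothing(logits, labels, label_smoothing=0.1)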


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be an odd number'
    return kernel_size // 2
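
# Example: a stride-1 conv with kernel_size=3 keeps its spatial size when
# padding = get_same_padding(3) == 1; get_same_padding((3, 5)) returns (1, 2).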


def shuffle_layer(x, groups):
    """Channel shuffle: interleave channels across groups (as in ShuffleNet)."""
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape: (N, C, H, W) -> (N, groups, C // groups, H, W)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # transpose the group and channel dimensions
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten back to (N, C, H, W)
    x = x.view(batch_size, -1, height, width)
    return x
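
# Illustrative check; the tensor values are an assumption for the example:
#
#   x = torch.arange(8.).view(1, 8, 1, 1)
#   shuffle_layer(x, groups=2).view(-1)
#   # -> tensor([0., 4., 1., 5., 2., 6., 3., 7.])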


def get_split_list(in_dim, child_num):
    # split in_dim into child_num nearly equal parts, larger parts first
    in_dim_list = [in_dim // child_num] * child_num
    for _i in range(in_dim % child_num):
        in_dim_list[_i] += 1
    return in_dim_list
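
# Example: get_split_list(10, 3) -> [4, 3, 3]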


def list_sum(x):
    if len(x) == 1:
        return x[0]
    else:
        return x[0] + list_sum(x[1:])


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params


def count_conv_flop(layer, x):
    # multiply-accumulate count for an nn.Conv2d applied to input x
    # (per sample; bias terms are not counted)
    out_h = int(x.size()[2] / layer.stride[0])
    out_w = int(x.size()[3] / layer.stride[1])
    delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
        layer.kernel_size[1] * out_h * out_w / layer.groups
    return delta_ops
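
# Illustrative usage; the layer and input shapes are assumptions for the example:
#
#   conv = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
#   x = torch.randn(1, 16, 56, 56)
#   count_conv_flop(conv, x)   # 16 * 32 * 3 * 3 * 56 * 56 / 1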


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape rather than view: the sliced tensor is non-contiguous
        # in recent PyTorch versions
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
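
# Illustrative usage; shapes are assumptions for the example:
#
#   logits = torch.randn(32, 10)
#   labels = torch.randint(0, 10, (32,))
#   top1, top5 = accuracy(logits, labels, topk=(1, 5))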


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
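
# Illustrative usage; the numbers are assumptions for the example:
#
#   top1 = AverageMeter()
#   top1.update(97.5, n=32)   # batch metric weighted by batch size
#   top1.avg                  # running average over all updates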


class BasicUnit(nn.Module):
    """Abstract interface for network building blocks; subclasses implement these hooks."""

    def forward(self, x):
        raise NotImplementedError

    @property
    def unit_str(self):
        raise NotImplementedError

    @property
    def config(self):
        raise NotImplementedError

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError

    def get_flops(self, x):
        raise NotImplementedError