SRCNN.py
import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
from torch.nn.functional import interpolate


class NewInterpolate(torch.autograd.Function):
    """Bicubic interpolation whose ONNX export maps to a Resize node with a
    dynamic 'scales' input, so the upscaling factor is not baked in at trace time."""

    @staticmethod
    def symbolic(g, input, scales):
        # Build an ONNX Resize op: the empty constant fills the unused 'roi'
        # input, and 'scales' is wired through as a real graph input.
        return g.op("Resize",
                    input,
                    g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32)),
                    scales,
                    coordinate_transformation_mode_s="pytorch_half_pixel",
                    cubic_coeff_a_f=-0.75,
                    mode_s='cubic',
                    nearest_mode_s="floor")

    @staticmethod
    def forward(ctx, input, scales):
        # Eager-mode behaviour: keep only the height/width entries of the NCHW
        # scale tensor and fall back to plain bicubic interpolation.
        scales = scales.tolist()[-2:]
        return interpolate(input,
                           scale_factor=scales,
                           mode='bicubic',
                           align_corners=False)
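

# A minimal eager-mode sanity check (commented out; not part of the original
# script): forward() simply delegates to interpolate(), so applying NewInterpolate
# with an NCHW scale tensor of [1, 1, 2, 2] should match a plain 2x bicubic call.
# _x = torch.randn(1, 3, 8, 8)
# _scales = torch.tensor([1, 1, 2, 2], dtype=torch.float)
# assert torch.allclose(
#     NewInterpolate.apply(_x, _scales),
#     interpolate(_x, scale_factor=2, mode='bicubic', align_corners=False))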


class SuperResolutionNet(nn.Module):
    """SRCNN: upsample the input image, then refine it with three convolutions."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        # A plain interpolate call would freeze upscale_factor into the traced
        # graph as a constant; NewInterpolate keeps it as a model input instead.
        # x = interpolate(x,
        #                 scale_factor=upscale_factor.item(),
        #                 mode='bicubic',
        #                 align_corners=False)
        x = NewInterpolate.apply(x, upscale_factor)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


# Download the pretrained SRCNN checkpoint and a test image.
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
        'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
    if not os.path.exists(name):
        with open(name, 'wb') as f:
            f.write(requests.get(url).content)


def init_torch_model():
    torch_model = SuperResolutionNet()
    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint: strip the leading submodule prefix from every key
    # so the names match this standalone SuperResolutionNet.
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_torch_model()
# Reference PyTorch inference on the test image (kept commented out):
# input_img = cv2.imread('face.png').astype(np.float32)
#
# # HWC to NCHW
# input_img = np.transpose(input_img, [2, 0, 1])
# input_img = np.expand_dims(input_img, 0)
#
# # Inference
# torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy()
#
# # NCHW to HWC
# torch_output = np.squeeze(torch_output, 0)
# torch_output = np.clip(torch_output, 0, 255)
# torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
#
# # Show image
# cv2.imwrite("face_torch.png", torch_output)
"""torch to onnx"""
x = torch.randn(1, 3, 256, 256)
factor = torch.tensor([1,1,3,3], dtype=torch.float)
with torch.no_grad():
torch.onnx.export(
model,
(x, factor),
"srcnn3.onnx",
opset_version=11,
input_names=['input', 'factor'],
output_names=['output'])
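

# A possible follow-up check (a sketch, not part of the original script; it assumes
# the onnxruntime package is available and that face.png matches the 256x256 input
# size fixed by the export above, since no dynamic_axes were specified). Because
# 'factor' is a real graph input, the same srcnn3.onnx can be run at other scales,
# e.g. 4x here even though a 3x factor was used while tracing.
import onnxruntime

input_img = cv2.imread('face.png').astype(np.float32)
input_img = np.transpose(input_img, [2, 0, 1])   # HWC to NCHW
input_img = np.expand_dims(input_img, 0)
input_factor = np.array([1, 1, 4, 4], dtype=np.float32)

ort_session = onnxruntime.InferenceSession("srcnn3.onnx")
ort_output = ort_session.run(None, {'input': input_img,
                                    'factor': input_factor})[0]

ort_output = np.squeeze(ort_output, 0)           # NCHW back to HWC
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)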