-
Notifications
You must be signed in to change notification settings - Fork 192
/
testVid.lua
315 lines (261 loc) · 12.9 KB
/
testVid.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
require 'torch'
require 'unsup'
require 'nn'
require 'image'
require 'paths'
require 'lib/AdaptiveInstanceNormalization'
require 'lib/utils'
-- Command-line interface. Each spec is {flag, default, help}; specs are registered
-- with the CmdLine parser in exactly the order they are listed here.
local cmd = torch.CmdLine()
local optionSpecs = {
  -- Basic options
  {'-style', '', 'File path to the style image, or multiple style images separated by commas if you want to do style interpolation or spatial control'},
  {'-styleDir', '', 'Directory path to a batch of style images'},
  {'-content', '', 'File path to the content image'},
  {'-contentDir', '', 'Directory path to a batch of content images'},
  {'-vgg', 'models/vgg_normalised.t7', 'Path to the VGG network'},
  {'-decoder', 'models/decoder.t7', 'Path to the decoder'},
  -- Additional options
  {'-contentSize', 512, 'New (minimum) size for the content image, keeping the original size if set to 0'},
  {'-styleSize', 512, 'New (minimum) size for the style image, keeping the original size if set to 0'},
  {'-crop', false, 'If true, center crop both content and style image before resizing'},
  {'-saveExt', 'jpg', 'The extension name of the output image'},
  {'-gpu', 0, 'Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1'},
  {'-outputDir', 'output', 'Directory to save the output image(s)'},
  {'-saveOriginal', false, 'If true, also save the original content and style images in the output directory'},
  -- Advanced options
  {'-preserveColor', false, 'If true, preserve color of the content image'},
  {'-alpha', 1, 'The weight that controls the degree of stylization. Should be between 0 and 1'},
  {'-styleInterpWeights', '', 'The weight for blending the style of multiple style images'},
  {'-mask', '', 'Mask to apply spatial control, assume to be the path to a binary mask of the same size as content image'},
}
for _, spec in ipairs(optionSpecs) do
  cmd:option(spec[1], spec[2], spec[3])
end
-- Global on purpose: the rest of the script reads `opt`.
opt = cmd:parse(arg)
-- Echo the parsed options, load the CUDA backends when a GPU is requested, and
-- validate the input/model arguments before any model is loaded.
print(opt)
if opt.gpu >= 0 then
  require 'cudnn'
  require 'cunn'
end
-- Flags are registered with a single dash (e.g. `-content`), so messages name them that way.
assert(opt.content ~= '' or opt.contentDir ~= '', 'Either -content or -contentDir should be given.')
assert(opt.style ~= '' or opt.styleDir ~= '', 'Either -style or -styleDir should be given.')
assert(opt.content == '' or opt.contentDir == '', '-content and -contentDir cannot both be given.')
assert(opt.style == '' or opt.styleDir == '', '-style and -styleDir cannot both be given.')
-- Check both model files up front so a bad path fails with a clear message.
assert(paths.filep(opt.vgg), 'VGG network ' .. opt.vgg .. ' does not exist.')
assert(paths.filep(opt.decoder), 'Decoder ' .. opt.decoder .. ' does not exist.')
-- Encoder: load the VGG network and drop modules 53 down to 32, keeping modules 1..31.
-- Removal goes from the back so the remaining indices stay valid.
-- NOTE(review): presumably this truncates the encoder at relu4_1 — confirm against the model file.
vgg = torch.load(opt.vgg)
local lastKeptModule = 31
for idx = 53, lastKeptModule + 1, -1 do
  vgg:remove(idx)
end
-- The AdaIN layer is sized from the nOutputPlane of the second-to-last remaining encoder module.
local adain = nn.AdaptiveInstanceNormalization(vgg:get(#vgg - 1).nOutputPlane)
decoder = torch.load(opt.decoder)
if opt.gpu < 0 then
  -- CPU mode: everything runs in single precision.
  vgg:float()
  adain:float()
  decoder:float()
else
  cutorch.setDevice(opt.gpu + 1) -- -gpu is zero-indexed, cutorch devices are one-indexed
  vgg = cudnn.convert(vgg, cudnn):cuda() -- encoder layers are converted to cudnn before moving to GPU
  adain:cuda()
  decoder:cuda() -- the decoder is moved to GPU without cudnn conversion
end
-- Moves a tensor to the backend selected by -gpu: CUDA when opt.gpu >= 0, float on CPU.
local function toBackend(t)
  if opt.gpu >= 0 then
    return t:cuda()
  end
  return t:float()
end

-- Spatial control: AdaIN with style 1 where the resized mask equals 1 (foreground) and
-- style 2 where it equals 0 (background). Reads the global `mask` loaded from -mask.
local function spatialTarget(contentFeature, styleFeature)
  assert(styleFeature:size(1) == 2) -- expect two style images
  local C, H, W = contentFeature:size(1), contentFeature:size(2), contentFeature:size(3)
  local maskResized = image.scale(mask, W, H, 'simple')
  local maskView = maskResized:view(-1)
  local fgmask = torch.LongTensor(torch.find(maskView, 1)) -- foreground indices
  local bgmask = torch.LongTensor(torch.find(maskView, 0)) -- background indices
  local contentFeatureView = contentFeature:view(C, -1)
  local contentFeatureFG = contentFeatureView:index(2, fgmask):view(C, fgmask:nElement(), 1) -- C * #fg
  local contentFeatureBG = contentFeatureView:index(2, bgmask):view(C, bgmask:nElement(), 1) -- C * #bg
  -- The FG result is cloned before the second adain:forward.
  -- Presumably forward reuses its output buffer, which the BG call would overwrite.
  local targetFeatureFG = adain:forward({contentFeatureFG, styleFeature[1]}):clone():squeeze()
  local targetFeatureBG = adain:forward({contentFeatureBG, styleFeature[2]}):squeeze()
  local targetFeature = contentFeatureView:clone():zero() -- C * (H*W)
  targetFeature:indexCopy(2, fgmask, targetFeatureFG)
  targetFeature:indexCopy(2, bgmask, targetFeatureBG)
  return targetFeature:viewAs(contentFeature)
end

-- Style interpolation: weighted sum of AdaIN outputs, one per style image, using the
-- normalized global `styleInterpWeights` parsed from -styleInterpWeights.
local function interpolatedTarget(contentFeature, styleFeature)
  assert(styleFeature:size(1) == #styleInterpWeights, '-styleInterpWeights and -style must have the same number of elements')
  local targetFeature = contentFeature:clone():zero()
  for i = 1, styleFeature:size(1) do
    targetFeature:add(styleInterpWeights[i], adain:forward({contentFeature, styleFeature[i]}))
  end
  return targetFeature
end

-- Builds the AdaIN target features for the active mode (mask / interpolation / single
-- style), blends them with the content features by -alpha, and decodes the result.
local function decodeFromFeatures(contentFeature, styleFeature)
  local targetFeature
  if opt.mask ~= '' then -- spatial control
    targetFeature = spatialTarget(contentFeature, styleFeature)
  elseif opt.styleInterpWeights ~= '' then -- style interpolation
    targetFeature = interpolatedTarget(contentFeature, styleFeature)
  else
    targetFeature = adain:forward({contentFeature, styleFeature})
  end
  targetFeature = targetFeature:squeeze()
  targetFeature = opt.alpha * targetFeature + (1 - opt.alpha) * contentFeature
  return decoder:forward(targetFeature)
end

-- Stylizes `content` using precomputed encoder features of the style image(s), so the
-- caller can encode a style once and reuse it for many content images.
-- @param content content image tensor (any type; moved to the -gpu backend here)
-- @param styleFeature encoder output for the style image(s)
-- @return decoder output
local function styleTransferModified(content, styleFeature)
  local contentFeature = vgg:forward(toBackend(content)):clone()
  return decodeFromFeatures(contentFeature, styleFeature)
end

-- Stylizes `content` with `style` (a single image, or a stack of images for blending or
-- spatial control). Encodes the style and delegates to styleTransferModified.
-- @return decoder output
local function styleTransfer(content, style)
  -- Clone: the encoder is reused for the content image right after.
  local styleFeature = vgg:forward(toBackend(style)):clone()
  return styleTransferModified(content, styleFeature)
end
print('Creating save folder at ' .. opt.outputDir)
paths.mkdir(opt.outputDir)
if opt.mask ~= '' then
  -- Global on purpose: read by the style-transfer functions defined above.
  mask = image.load(opt.mask, 1, 'float') -- binary mask
end

-- Collect content and style inputs. A stylePaths entry is either a single path or, when
-- -style lists several comma-separated images, a table of paths to be blended.
local contentPaths = {}
local stylePaths = {}
if opt.content ~= '' then -- use a single content image
  table.insert(contentPaths, opt.content)
else -- use a batch of content images
  assert(opt.contentDir ~= '', "Either opt.contentDir or opt.content should be non-empty!")
  contentPaths = extractImageNamesRecursive(opt.contentDir)
end
if opt.style ~= '' then
  local style_image_list = opt.style:split(',')
  if #style_image_list == 1 then
    style_image_list = style_image_list[1]
  end
  table.insert(stylePaths, style_image_list)
else -- use a batch of style images
  assert(opt.styleDir ~= '', "Either opt.styleDir or opt.style should be non-empty!")
  stylePaths = extractImageNamesRecursive(opt.styleDir)
end

if opt.styleInterpWeights ~= '' then
  -- Global on purpose: read by the style-transfer functions defined above.
  -- Parse explicitly instead of relying on string-to-number coercion, so bad input fails clearly.
  styleInterpWeights = {}
  local styleInterpWeightsSum = 0
  for i, rawWeight in ipairs(opt.styleInterpWeights:split(',')) do
    local weight = tonumber(rawWeight)
    assert(weight, 'Invalid -styleInterpWeights entry: ' .. tostring(rawWeight))
    styleInterpWeights[i] = weight
    styleInterpWeightsSum = styleInterpWeightsSum + weight
  end
  -- A zero sum would turn every normalized weight into NaN/inf.
  assert(styleInterpWeightsSum ~= 0, '-styleInterpWeights must not sum to zero')
  for i = 1, #styleInterpWeights do -- normalize weights so they sum to 1
    styleInterpWeights[i] = styleInterpWeights[i] / styleInterpWeightsSum
  end
end
local numContent = #contentPaths
local numStyle = #stylePaths
print("# Content images: " .. numContent)
print("# Style images: " .. numStyle)

-- Loads a content image as a 3-channel float image, preprocessed per -crop/-contentSize.
-- Returns the image, its base name (without extension) and its extension.
local function loadContent(contentPath)
  local contentExt = paths.extname(contentPath)
  local contentName = paths.basename(contentPath, contentExt)
  local contentImg = image.load(contentPath, 3, 'float')
  contentImg = sizePreprocess(contentImg, opt.crop, opt.contentSize)
  return contentImg, contentName, contentExt
end

-- Loads one stylePaths entry: a single path, or a table of paths for style blending.
-- The blended images are stacked with add_dummy + torch.cat.
-- When colorRef is given, each style image is passed through coral(style, colorRef).
-- Returns the style tensor, the name used in output file names, and the extension.
-- For a blend, the returned extension is that of the last file.
local function loadStyle(stylePath, colorRef)
  if type(stylePath) == "table" then -- style blending
    local imgs, names, ext = {}, {}, nil
    for s = 1, #stylePath do
      ext = paths.extname(stylePath[s])
      names[#names + 1] = paths.basename(stylePath[s], ext)
      local img = sizePreprocess(image.load(stylePath[s], 3, 'float'), opt.crop, opt.styleSize)
      if colorRef then
        img = coral(img, colorRef)
      end
      imgs[#imgs + 1] = img:add_dummy()
    end
    return torch.cat(imgs, 1), table.concat(names, '_'), ext
  end
  local ext = paths.extname(stylePath)
  local img = sizePreprocess(image.load(stylePath, 3, 'float'), opt.crop, opt.styleSize)
  if colorRef then
    img = coral(img, colorRef)
  end
  return img, paths.basename(stylePath, ext), ext
end

-- Writes one stylized result and, with -saveOriginal, the inputs it was made from.
local function saveResult(output, contentImg, contentName, contentExt, styleImg, styleName, styleExt)
  local savePath = paths.concat(opt.outputDir, contentName .. '_stylized_' .. styleName .. '.' .. opt.saveExt)
  print('Output image saved at: ' .. savePath)
  image.save(savePath, output)
  if opt.saveOriginal then
    -- also save the original images
    image.save(paths.concat(opt.outputDir, contentName .. '.' .. contentExt), contentImg)
    -- NOTE(review): for blended styles styleImg is a stacked batch; confirm image.save accepts it.
    image.save(paths.concat(opt.outputDir, styleName .. '.' .. styleExt), styleImg)
  end
end

if opt.preserveColor then
  -- coral() matches each style image to the current content image,
  -- so the style must be reloaded and re-encoded for every pair.
  for i = 1, numContent do
    local contentImg, contentName, contentExt = loadContent(contentPaths[i])
    for j = 1, numStyle do -- generate a transferred image for each (content, style) pair
      local styleImg, styleName, styleExt = loadStyle(stylePaths[j], contentImg)
      local output = styleTransfer(contentImg, styleImg)
      saveResult(output, contentImg, contentName, contentExt, styleImg, styleName, styleExt)
    end
  end
else
  for j = 1, numStyle do
    local styleImg, styleName, styleExt = loadStyle(stylePaths[j])
    if opt.gpu >= 0 then
      styleImg = styleImg:cuda()
    else
      styleImg = styleImg:float()
    end
    -- calculate style feature only once to improve speed
    local styleFeature = vgg:forward(styleImg):clone()
    for i = 1, numContent do -- generate a transferred image for each (content, style) pair
      local contentImg, contentName, contentExt = loadContent(contentPaths[i])
      local output = styleTransferModified(contentImg, styleFeature)
      saveResult(output, contentImg, contentName, contentExt, styleImg, styleName, styleExt)
    end
  end
end