-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathcsf_train.m
105 lines (92 loc) · 3.91 KB
/
csf_train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
%CSF_TRAIN - Train (cascades of) shrinkage fields.
% Author: Uwe Schmidt, TU Darmstadt ([email protected])
%
% This file is part of the implementation as described in the CVPR 2014 paper:
% Uwe Schmidt and Stefan Roth. Shrinkage Fields for Effective Image Restoration.
% Please see the file LICENSE.txt for the license governing this code.
function csf_train
%% OPTIONS
pairwise  = false; % use the pairwise model instead of a learned filter bank
do_deblur = false; % train for deblurring rather than denoising
static = struct;
% MODEL
static.nstages         = 3;
static.unitnorm_filter = true;
if pairwise
  static.fbank             = 'pw';
  static.fbank_sz          = [];
  static.do_filterlearning = false;
  static.discard_bndry     = 5;
else
  static.fbank             = 'dct';
  static.fbank_sz          = 3;
  static.do_filterlearning = true;
  static.discard_bndry     = 2*static.fbank_sz;
end
% APPLICATION
static.is_deblurring = do_deblur;
if do_deblur
  static.sigma = 0.5;
else
  static.sigma = 25;
end
% LEARNING
static.nimages           = 5;
static.use_lut           = true;
static.do_plot           = true;
static.maxoptiters       = 25;
static.do_joint_training = false;
if do_deblur
  static.ntapers  = 3;
  % largest kernel size in training set (assumed to be bigger than filter size)
  static.k_sz_max = 37;
  static.discard_bndry = (static.k_sz_max-1)/2;
end
% constrain lambda to be positive:
% static.pos = train.pos_exp; % as used in the paper
static.pos = train.pos_expident; % typically works better
%% SETUP
static.imdims  = [128,128];
static.npixels = prod(static.imdims);
static = train.get_filters(static);
static = train.precompute_model(static);
dat  = train.get_data('data',static);
dat  = train.precompute_data(dat);
pred = dat.Y; % prediction from previous stage (initialize with observed data)
[static.shrink,static.THETA,learning] = train.init_params(static);
% note: recommended for joint training to use parameters from
% greedily trained models as initialization, e.g.:
% load('model.mat','learning');
% theta0 = arrayfun(@(t){misc.struct2vec(t)}, learning.THETA(2:end));
% theta0 = vertcat(theta0{:});
if static.do_plot, figure(1), clf, colormap(gray(256)), end
%% LEARNING
experiment = 'model.mat';
if static.do_joint_training
  % jointly optimize the parameters of all stages at once
  if ~exist('theta0','var'), theta0 = repmat(learning.theta(:,1),static.nstages,1); end
  shrinkage = rbfmix.from_struct(repmat(static.shrink,static.nfilters,static.nstages));
  cost_func = @(theta) train.objective_all_stages(reshape(theta,[],static.nstages), dat, static, shrinkage);
  minimizer = train.get_minimizer(static,cost_func,theta0);
  learning.theta(:,2:end) = reshape(minimizer(),[],static.nstages);
  for stage = 1:static.nstages
    learning.THETA(stage+1) = misc.vec2struct(learning.theta(:,stage+1),static.THETA);
    [pred,learning.psnrs(:,stage+1)] = train.predict(pred, learning.theta(:,stage+1), dat, static, shrinkage(:,stage));
  end
  if ~isempty(experiment), save(experiment,'static','learning'); end
else
  % greedy stage-by-stage training: each stage is optimized on the
  % predictions of the previous one, and the model is saved after every stage
  if ~exist('theta0','var'), theta0 = learning.theta(:,1); end
  shrinkage = rbfmix.from_struct(repmat(static.shrink,static.nfilters,1));
  cost_func = @(theta,U) train.objective_one_stage(U, theta, dat, static, shrinkage);
  minimizer = train.get_minimizer(static,cost_func,theta0);
  for stage = 1:static.nstages
    learning.theta(:,stage+1) = minimizer(pred);
    learning.THETA(stage+1)   = misc.vec2struct(learning.theta(:,stage+1),static.THETA);
    [pred,learning.psnrs(:,stage+1)] = train.predict(pred, learning.theta(:,stage+1), dat, static, shrinkage);
    if ~isempty(experiment), save(experiment,'static','learning'); end
  end
end
% report per-stage training performance
fprintf('\nAvg. PSNR on training set:\n');
for stage = 1:static.nstages
  fprintf('- Stage %d: %.2fdB\n', stage, mean(learning.psnrs(:,stage+1),1));
end
end