-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_routing.py
146 lines (109 loc) · 4.03 KB
/
test_routing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import os

import numpy as np
import torch
from skimage import io
from tqdm import tqdm

from utils.loading import load_config
from utils.setup import *  # NOTE(review): presumably supplies get_data_config / get_data; confirm
from modules.routing import ConfidenceRouting
def arg_parser():
    """Parse the command line of the routing test script.

    Returns:
        dict: mapping of option name to value; ``"config"`` holds the value
        passed via ``--config`` or ``None`` when the flag is omitted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=False)
    return vars(cli.parse_args())
def prepare_input_data(batch, config, device):
    """Build the network input tensor and the target tensor from one batch.

    The ``<sensor>_depth`` entry of ``batch`` for every sensor listed in
    ``config.DATA.input`` gets a new channel axis at dim 1. The entries are then
    concatenated in *reverse* config order: the last listed sensor becomes
    channel 0. When ``config.ROUTING.intensity_grad`` is set,
    ``batch["intensity"]`` and ``batch["gradient"]`` are prepended as
    channels 0 and 1.

    The tensors stored in ``batch`` are left unmodified. All channels are
    concatenated before the single move to ``device``, so tensors on
    different devices are never mixed.

    Args:
        batch: mapping of key -> tensor (plus any other entries), as yielded by
            the data loader. Per-sensor tensors are presumably
            (batch, height, width); TODO confirm.
        config: configuration object exposing ``DATA.input``, ``DATA.target``
            and ``ROUTING.intensity_grad``.
        device: torch device to place the returned tensors on.

    Returns:
        tuple: ``(inputs, target)``. ``inputs`` is the channel-stacked network
        input. ``target`` is ``batch[config.DATA.target]`` with a channel axis
        inserted at dim 1, i.e. (batch size, channels, height, width).

    Raises:
        ValueError: if ``config.DATA.input`` is empty.
    """
    sensors = list(config.DATA.input)
    if not sensors:
        raise ValueError("config.DATA.input must list at least one sensor")

    # Reverse order reproduces the channel layout the model was built around.
    channels = [batch[sensor + "_depth"].unsqueeze(1) for sensor in reversed(sensors)]
    if config.ROUTING.intensity_grad:
        channels = [
            batch["intensity"].unsqueeze(1),
            batch["gradient"].unsqueeze(1),
        ] + channels

    # Concatenate on the source device first, then move once.
    inputs = torch.cat(channels, 1).to(device)

    # (batch size, height, width) -> (batch size, channels, height, width)
    target = batch[config.DATA.target].unsqueeze(1).to(device)
    return inputs, target
def _routing_output_dir(config, frame_id, kind):
    """Return the output directory for outputs of type ``kind`` of ``frame_id``.

    The layout is ``<root_dir>/<part0>/<part1>/left_routing_<kind>_<run>``.
    ``part0``/``part1`` are the first two ``/``-separated components of
    ``frame_id``. ``run`` is the third-from-last component of
    ``config.TESTING.model_path``.
    """
    parts = frame_id.split("/")
    run_name = config.TESTING.model_path.split("/")[-3]
    return os.path.join(
        config.DATA.root_dir, parts[0], parts[1], "left_routing_" + kind + "_" + run_name
    )


def _save_routing_outputs(config, frame_id, est, unc):
    """Write the refined-estimate and confidence PNGs of one frame.

    Args:
        config: configuration object (``DATA.root_dir``, ``TESTING.model_path``).
        frame_id: frame identifier string; its last ``/`` component is the
            file stem.
        est: 2-D float array, output channel 0 of the model.
        unc: 2-D float array, output channel 1 of the model.
    """
    file_name = frame_id.split("/")[-1] + ".png"
    refined_dir = _routing_output_dir(config, frame_id, "refined")
    confidence_dir = _routing_output_dir(config, frame_id, "confidence")
    os.makedirs(refined_dir, exist_ok=True)
    os.makedirs(confidence_dir, exist_ok=True)

    # Scale by 1000 (presumably metres -> millimetres; confirm). Clip so that
    # out-of-range values saturate instead of wrapping in the uint16 cast.
    est_u16 = np.clip(est * 1000, 0, np.iinfo(np.uint16).max).astype("uint16")
    # The confidence decoder ends in a ReLU, so unc >= 0 and
    # exp(-unc) * 10000 lies in (0, 10000].
    confidence_u16 = (np.exp(-1.0 * unc) * 10000).astype("uint16")

    io.imsave(os.path.join(refined_dir, file_name), est_u16)
    io.imsave(os.path.join(confidence_dir, file_name), confidence_u16)


def test(config):
    """Run a trained ConfidenceRouting model on the test split and save outputs.

    For every frame in every batch, two 16-bit PNGs are written below
    ``config.DATA.root_dir`` (see ``_routing_output_dir``):

    * ``left_routing_refined_<run>/<frame>.png``: output channel 0 * 1000.
    * ``left_routing_confidence_<run>/<frame>.png``: ``exp(-c1) * 10000``,
      where ``c1`` is output channel 1.

    The model runs in eval mode under ``torch.no_grad()``.

    Args:
        config: configuration object providing the ``SETTINGS``, ``DATA``,
            ``ROUTING``, ``MODEL`` and ``TESTING`` sections read below.
    """
    device = torch.device("cuda:0" if config.SETTINGS.gpu else "cpu")

    # get test dataset
    test_data_config = get_data_config(config, mode="test")
    test_dataset = get_data(config.DATA.dataset, test_data_config)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, config.TESTING.test_batch_size, config.TESTING.test_shuffle
    )

    # One input channel per sensor, plus intensity and gradient when enabled
    # (must match the channels built by prepare_input_data).
    Cin = len(config.DATA.input)
    if config.ROUTING.intensity_grad:
        Cin += 2
    model = ConfidenceRouting(
        Cin=Cin, F=config.MODEL.contraction, batchnorms=config.MODEL.normalization
    )

    # map_location allows a checkpoint saved on GPU to be loaded on a CPU run.
    checkpoint = torch.load(config.TESTING.model_path, map_location=device)
    model.load_state_dict(checkpoint["pipeline_state_dict"])
    model = model.to(device)
    # Inference mode: batch-norm uses running statistics, dropout disabled.
    model.eval()

    with torch.no_grad():
        for batch in tqdm(test_loader, total=len(test_loader)):
            inputs, _ = prepare_input_data(batch, config, device)
            output = model(inputs)
            est = output[:, 0, :, :].cpu().numpy()
            unc = output[:, 1, :, :].cpu().numpy()
            # Save every sample of the batch under its own frame id.
            for j, frame_id in enumerate(batch["frame_id"]):
                _save_routing_outputs(config, frame_id, est[j], unc[j])
if __name__ == "__main__":
    # Read the command line; a config file is mandatory for this script.
    cli_args = arg_parser()
    if not cli_args["config"]:
        raise ValueError("Missing configuration: Please specify config.")

    # Load the experiment configuration and run inference on the test split.
    config = load_config(cli_args["config"])
    test(config)