-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathinference.py
82 lines (65 loc) · 2.62 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3.6
import argparse
import warnings
from typing import List
from pathlib import Path
import torch
import numpy as np
import torch.nn.functional as F
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader
from dataloader import SliceDataset
from utils import save_images, map_, tqdm_, probs2class, uniq
def runInference(args: argparse.Namespace):
    """Run a saved model on every ``*.png`` in ``args.data_folder`` and save predictions.

    Args:
        args: namespace produced by ``get_args``. Fields read:
            ``model_weights`` (file passed to ``torch.load``),
            ``data_folder`` (folder searched for ``*.png`` inputs),
            ``save_folder`` (destination handed to ``save_images``),
            ``num_classes`` (forwarded to ``SliceDataset`` as ``C``) and
            ``batch_size`` (DataLoader batch size; also sets the worker count).
    """
    print('>>> Loading model')
    # Use the GPU when one is present, otherwise fall back to CPU instead of
    # failing on an unconditional torch.device("cuda").
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps tensors stored on another device (e.g. a checkpoint
    # written from a GPU run) onto the selected device.
    # NOTE(review): torch.load unpickles a whole module object, which can run
    # arbitrary code -- only point --model_weights at trusted files.
    net = torch.load(args.model_weights, map_location=device)
    net.to(device)
    # Switch dropout / batch-norm layers to inference behaviour; a model saved
    # during training would otherwise still be in training mode here.
    net.eval()

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes
    transform = transforms.Compose([
        # Add a leading channel axis (assumes single-channel images -- TODO confirm).
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1 (assumes 8-bit pixel values)
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])
    folders: List[Path] = [Path(args.data_folder)]
    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(names,
                          folders,
                          transforms=[transform],
                          debug=False,
                          C=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,    # no shuffling needed for inference
                        drop_last=False)  # also predict the final partial batch

    print('>>> Starting the inference')
    savedir: str = args.save_folder
    tq_iter = tqdm_(loader, total=len(loader), desc=">> Inference")
    with torch.no_grad():  # gradients are not needed at inference time
        for filenames, image, _ in tq_iter:
            image = image.to(device)
            pred_logits: Tensor = net(image)
            # Softmax over dim 1 -- presumably the class axis; verify against the model.
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)
            with warnings.catch_warnings():
                # Silence any warnings raised while converting / saving predictions.
                warnings.simplefilter("ignore")
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, "", 0)
def get_args() -> argparse.Namespace:
    """Parse the inference script's command-line options, print them, and return them.

    Returns:
        Namespace with ``data_folder``, ``save_folder``, ``model_weights``
        (all required strings), ``num_classes`` (int, default 4) and
        ``batch_size`` (int, default 10).
    """
    cli = argparse.ArgumentParser(description='Inference parameters')
    cli.add_argument('--data_folder', type=str, required=True,
                     help="The folder containing the images to predict")
    # Remaining required path options carry no help text.
    for required_flag in ('--save_folder', '--model_weights'):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument('--num_classes', type=int, default=4)
    cli.add_argument('--batch_size', type=int, default=10)

    parsed = cli.parse_args()
    print(parsed)  # echo the effective configuration
    return parsed
if __name__ == '__main__':
    # Script entry point: parse the command line, then run inference with it.
    runInference(get_args())