-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
100 lines (75 loc) · 3.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Author: Daiwei (David) Lu
# Useful Utils
import torch
from skimage import io, transform
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torchvision import utils
import torchvision.transforms.functional as TF
class TestFile(Dataset):
    """Dataset exposing a single image file as its one and only sample."""

    def __init__(self, file, transform=None):
        """Remember the image location and the optional transform.

        Args:
            file: Location of the image; handed to ``io.imread`` on access.
            transform: Optional callable applied to the loaded image.
        """
        self.transform = transform
        self.file = file

    def __len__(self):
        """Return 1 -- the dataset always reports exactly one sample."""
        return 1

    def __getitem__(self, idx):
        """Read the image and return it, transformed when a transform is set.

        ``idx`` is not consulted: every index yields the same image.
        """
        image = io.imread(self.file)
        return self.transform(image) if self.transform else image
class Rescale(object):
    """Resize an image to a fixed output shape.

    Args:
        output_size: Either an ``int`` (the image is resized to a square with
            that side length) or a ``tuple`` that is handed to
            ``transform.resize`` unchanged as the output shape, e.g.
            ``(height, width)``.

    Raises:
        TypeError: If ``output_size`` is neither an ``int`` nor a ``tuple``.
    """

    def __init__(self, output_size):
        # Explicit raise rather than ``assert`` so the check is not stripped
        # when Python runs with ``-O``.
        if not isinstance(output_size, (int, tuple)):
            raise TypeError(
                'output_size must be an int or a tuple, got '
                + type(output_size).__name__
            )
        self.output_size = output_size

    def _target_shape(self):
        """Return the shape tuple passed to ``transform.resize``."""
        if isinstance(self.output_size, int):
            return (self.output_size, self.output_size)
        # A tuple used to be duplicated into ((h, w), (h, w)), which is not a
        # valid shape; use it as the shape directly instead.
        return tuple(self.output_size)

    def __call__(self, sample):
        """Return ``sample`` resized to the configured output shape."""
        return transform.resize(sample, self._target_shape())
class Normalize(object):
    """Normalize an image tensor with fixed mean / std constants.

    Args:
        inplace: Forwarded to ``TF.normalize``; when true the input tensor
            itself is overwritten with the normalized values.
    """

    def __init__(self, inplace=False):
        # Three-element mean / std constants passed to TF.normalize.
        self.mean = (0.5692824, 0.55365936, 0.5400631)
        self.std = (0.1325967, 0.1339596, 0.14305606)
        self.inplace = inplace

    def __call__(self, sample):
        """Return the normalized image together with the un-normalized input.

        Args:
            sample: Image tensor accepted by ``TF.normalize``.

        Returns:
            dict: ``'image'`` holds the normalized tensor and ``'original'``
            the input as it was before normalization.
        """
        # With inplace=True the normalization overwrites ``sample``, so take a
        # snapshot first; otherwise 'original' would alias the normalized data.
        original = sample.clone() if self.inplace else sample
        return {'image': TF.normalize(sample, self.mean, self.std, self.inplace),
                'original': original}
class ToTensor(object):
    """Convert a 3-D NumPy image array into a channel-first float tensor."""

    def __call__(self, sample):
        """Move the last axis to the front and return a float32 tensor.

        Args:
            sample: 3-D ``numpy.ndarray``; axes are reordered with
                ``transpose((2, 0, 1))`` (e.g. H x W x C -> C x H x W).

        Returns:
            torch.Tensor: CPU tensor of dtype ``torch.float32``.
        """
        # The previous CUDA check selected torch.FloatTensor on both branches,
        # so the result was always a CPU float tensor; state that directly.
        image = sample.transpose((2, 0, 1))
        return torch.from_numpy(image).float()
def show_dot(image, coordinates):
    """Display ``image`` and mark the first coordinate pair with a red dot.

    The first entry of ``coordinates`` is scaled by the image width (x) and
    height (y) before plotting.
    """
    plt.imshow(image)
    height, width = image.shape[0], image.shape[1]
    point = coordinates[0]
    dot_x = width * point[0]
    dot_y = height * point[1]
    plt.scatter(dot_x, dot_y, marker='.', c='r')
    # Short pause so the figure gets a chance to refresh.
    plt.pause(0.001)
def batch_show(sample_batched, coord_scale=256):
    """Show image for a batch of samples with their coordinates overlaid.

    Args:
        sample_batched: Mapping with an ``'original'`` image batch, indexed
            here as (batch, channel, height, width), and a ``'coordinates'``
            tensor whose columns 0 and 1 hold the x and y values.
        coord_scale: Factor that converts the stored coordinates into pixels.
            Defaults to 256, the value that used to be hard-coded.
    """
    images_batch, coordinates_batch = \
        sample_batched['original'], sample_batched['coordinates']
    batch_size = len(images_batch)
    im_height, im_width = images_batch.size(2), images_batch.size(3)
    grid_border_size = 2
    # make_grid starts a new row after this many images (its default of 8 is
    # spelled out so the offsets below stay in sync with the grid layout).
    images_per_row = 8

    grid = utils.make_grid(images_batch, nrow=images_per_row,
                           padding=grid_border_size)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Locate the tile of image i inside the grid; each tile is preceded by
        # one border strip per row/column before it, plus the leading one.
        row, col = divmod(i, images_per_row)
        x_offset = col * im_width + (col + 1) * grid_border_size
        y_offset = row * im_height + (row + 1) * grid_border_size
        plt.scatter(coordinates_batch[i, 0].cpu().numpy() * coord_scale + x_offset,
                    coordinates_batch[i, 1].cpu().numpy() * coord_scale + y_offset,
                    marker='.', c='r')

    plt.title('Batch from dataloader')
def visualize_model(model, dataloaders, device):
    """Plot the model's predicted coordinates for every validation batch.

    The model is switched to eval mode for the duration of the call and its
    previous training flag is restored afterwards, even if plotting or the
    forward pass raises.

    Args:
        model: Module called on the ``'image'`` entry of each batch.
        dataloaders: Mapping whose ``'val'`` entry yields dict batches with
            ``'image'``, ``'coordinates'`` and (for ``batch_show``)
            ``'original'`` entries.
        device: Device the inputs are moved to before the forward pass.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloaders['val']:
                # Only .to(device): an unconditional .cuda() would fail on
                # machines without a GPU even when device is the CPU.
                inputs = batch['image'].float().to(device)
                print('Label:', batch['coordinates'].data)
                # Swap the labels for the predictions so batch_show draws the
                # model output; bring them to the CPU for plotting instead of
                # writing through .data across devices.
                batch['coordinates'] = model(inputs).detach().cpu()
                plt.figure()
                batch_show(batch)
                plt.axis('off')
                plt.ioff()
                plt.show()
    finally:
        model.train(mode=was_training)