-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfind_mean_std.py
41 lines (36 loc) · 1.08 KB
/
find_mean_std.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Author: Daiwei (David) Lu
# Find the PhoneDataset mean and std for normalization
from load import *
from torch.utils.data import DataLoader
def _merge_batch_stats(count, mean, m2, pixels):
    """Fold an (n, C) array of pixel vectors into running (count, mean, M2).

    Uses the pairwise combination of Chan et al.: numerically stable like the
    per-sample Welford update, but each batch is reduced with numpy instead of
    a Python loop over individual pixels.
    """
    n_new = pixels.shape[0]
    if n_new == 0:
        return count, mean, m2
    batch_mean = pixels.mean(axis=0)
    batch_m2 = ((pixels - batch_mean) ** 2).sum(axis=0)
    if mean is None:
        # First non-empty batch: size the accumulators from the data instead
        # of hard-coding three channels.
        mean = np.zeros_like(batch_mean)
        m2 = np.zeros_like(batch_mean)
    total = count + n_new
    delta = batch_mean - mean
    mean = mean + delta * (n_new / total)
    m2 = m2 + batch_m2 + delta ** 2 * (count * n_new / total)
    return total, mean, m2


def compute_mean_std(loader, channel_axis=-1, verbose=True):
    """Return the per-channel mean and sample standard deviation of a loader.

    Args:
        loader: iterable of dicts whose ``'image'`` entry is a batch of images
            convertible with ``np.asarray`` (e.g. a CPU torch tensor).
        channel_axis: axis of each *batch* array that holds the channels.
            Defaults to the last axis, matching the original
            ``image.reshape((-1, image.shape[2]))`` on (B, H, W, C) batches.
        verbose: print each batch's shape as it is processed.

    Returns:
        ``(mean, std)`` as 1-D float64 arrays with one entry per channel;
        ``std`` uses the unbiased ``n - 1`` denominator.

    Raises:
        ValueError: if fewer than two pixels were seen, since the sample
            standard deviation is undefined then.
    """
    count, mean, m2 = 0, None, None
    for batch in loader:
        imgs = np.asarray(batch['image'], dtype=np.float64)
        if verbose:
            print(imgs.shape)
        n_channels = imgs.shape[channel_axis]
        # Flatten every spatial position of every image into one row per pixel.
        pixels = np.moveaxis(imgs, channel_axis, -1).reshape(-1, n_channels)
        count, mean, m2 = _merge_batch_stats(count, mean, m2, pixels)
    if count < 2:
        raise ValueError(
            f'need at least 2 pixels to estimate std, got {count}')
    return mean, np.sqrt(m2 / (count - 1))


def main():
    """Build the training PhoneDataset loader and print its channel mean/std."""
    # NOTE(review): Normalize() runs before the statistics are measured. If
    # load.Normalize already standardises by mean/std, the printed values
    # describe normalised data rather than raw pixels — confirm its behaviour.
    # Also assumes load.ToTensor yields channel-last images — TODO confirm;
    # otherwise pass channel_axis=1 to compute_mean_std.
    dataset = PhoneDataset('data/labels/train.txt',
                           'data',
                           mode='/train',
                           transform=transforms.Compose([
                               Rescale(256),
                               ToTensor(),
                               Normalize()
                           ]))
    loader = DataLoader(
        dataset,
        batch_size=10,
        num_workers=1,
        shuffle=False
    )
    pixel_mean, pixel_std = compute_mean_std(loader)
    print(pixel_mean)
    print(pixel_std)


# With num_workers > 0, platforms that start workers via "spawn" re-import
# this module in every worker; unguarded top-level code would rebuild and
# iterate the loader again during worker bootstrap.
if __name__ == '__main__':
    main()