-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
83 lines (61 loc) · 3.05 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import glob
import random
from torch.utils.data import Dataset
from utils import k_fold, AwbAug, feature_select
class CcData(Dataset):
    """Dataset of pre-processed ``.npy`` images and illuminant labels, with augmentation.

    Expects ``path`` to contain two sub-directories: ``numpy_data`` (network inputs)
    and ``numpy_labels`` (ground-truth illuminants). A label file has the same name
    as its input file. The sorted image list is split with ``utils.k_fold`` into 3
    folds; ``train`` picks the ``'train'`` or ``'test'`` part of fold ``fold_num``.

    Args:
        path: dataset root. It is concatenated directly with the sub-directory name,
            so it must end with a path separator (e.g. ``'data/'``).
        train: use the ``'train'`` part of the fold if True, otherwise ``'test'``.
        fold_num: index of the fold to use.
        test_sensor: forwarded to ``AwbAug`` as ``sensor_name``.

    Raises:
        FileNotFoundError: if no ``.npy`` file is found under ``path + 'numpy_data'``.
    """

    def __init__(self, path, train=True, fold_num=0, test_sensor=''):
        self.path = path
        self.train = train
        self.illu_full = glob.glob(path + 'numpy_labels' + '/*.npy')
        self.img_full = glob.glob(path + 'numpy_data' + '/*.npy')
        if not self.img_full:
            # Fail early: an empty list would otherwise reach k_fold and yield a
            # useless (or crashing) split. A common cause is a missing trailing '/'.
            raise FileNotFoundError(
                f"No .npy files found in {path + 'numpy_data'!r}; "
                "note that `path` must end with a path separator."
            )
        # Deterministic order is required so k_fold produces the same folds on
        # every machine.
        self.img_full.sort(key=self._sort_key)
        self.illu_full.sort(key=self._sort_key)
        train_test = k_fold(n_splits=3, num=len(self.img_full))
        img_idx = train_test['train' if self.train else 'test'][fold_num]
        self.fold_data = [self.img_full[i] for i in img_idx]
        self.fold_illu = [self._label_path(p, path) for p in self.fold_data]
        # NOTE(review): AwbAug receives the label files of *all* folds, including
        # this fold's test split — confirm this does not leak test illuminants
        # into the augmentation.
        self.data_aug = AwbAug(self.illu_full, sensor_name=test_sensor)

    @staticmethod
    def _sort_key(file_path):
        """Return the ordering key of a dataset file.

        The primary key is the text between the last '_' and the first '.' of the
        file name, compared as a string. The full file name breaks ties so the order
        never depends on glob's unspecified order. Forward and back slashes are both
        treated as separators, so the key is identical on Windows and POSIX.
        """
        name = file_path.replace('\\', '/').rsplit('/', 1)[-1]
        return name.split('_')[-1].split('.')[0], name

    @staticmethod
    def _label_path(data_path, root):
        """Map an input file path to the matching label file path.

        Only the ``numpy_data`` directory directly under ``root`` is replaced by
        ``numpy_labels``. A blanket ``str.replace`` would also rewrite any other
        'numpy_data' substring in ``root`` or in the file name.
        """
        prefix = root + 'numpy_data'
        if data_path.startswith(prefix):
            return root + 'numpy_labels' + data_path[len(prefix):]
        # Fallback when glob did not preserve the prefix verbatim.
        return data_path.replace('numpy_data', 'numpy_labels')

    def __len__(self):
        return len(self.fold_data)

    def __getitem__(self, idx):
        """Load, augment and featurize the sample at ``idx``.

        Note: the input data was pre-processed into '.npy' files for fast loading.
        To train on your own dataset, the image-loading code must be changed
        accordingly.

        Returns:
            ``(features, illuminant)``, both cast to ``np.float32``.
        """
        img_data = np.load(self.fold_data[idx])
        gd_data = np.load(self.fold_illu[idx])
        # Augmentation runs for both the train and the test split; an
        # `if self.train:` guard was deliberately left disabled here.
        img_data, gd_data = self.data_aug.awb_aug(gd_data, img_data)
        feature_data = feature_select(img_data)
        return feature_data.astype(np.float32), gd_data.astype(np.float32)
class CcDataEval(Dataset):
    """Evaluation dataset of pre-processed ``.npy`` images and illuminant labels.

    Expects ``path`` to contain ``numpy_data`` (inputs) and ``numpy_labels``
    (ground-truth illuminants). A label file has the same name as its input file.
    The sorted image list is split with ``utils.k_fold`` into 3 folds. No
    augmentation is applied, and each item also carries the image name.

    Args:
        path: dataset root. It is concatenated directly with the sub-directory name,
            so it must end with a path separator (e.g. ``'data/'``).
        train: use the ``'train'`` part of the fold if True, otherwise ``'test'``.
        fold_num: index of the fold to use.

    Raises:
        FileNotFoundError: if no ``.npy`` file is found under ``path + 'numpy_data'``.
    """

    def __init__(self, path, train=False, fold_num=0):
        self.path = path
        self.train = train
        self.img_full = glob.glob(path + 'numpy_data' + '/*.npy')
        if not self.img_full:
            # Fail early instead of handing an empty list to k_fold; a common cause
            # is a missing trailing '/' on `path`.
            raise FileNotFoundError(
                f"No .npy files found in {path + 'numpy_data'!r}; "
                "note that `path` must end with a path separator."
            )
        # Deterministic order is required so k_fold produces the same folds on
        # every machine.
        self.img_full.sort(key=self._sort_key)
        train_test = k_fold(n_splits=3, num=len(self.img_full))
        img_idx = train_test['train' if self.train else 'test'][fold_num]
        self.fold_data = [self.img_full[i] for i in img_idx]
        self.fold_illu = [self._label_path(p, path) for p in self.fold_data]

    @staticmethod
    def _sort_key(file_path):
        """Return the ordering key of a dataset file.

        The primary key is the text between the last '_' and the first '.' of the
        file name, compared as a string. The full file name breaks ties so the order
        never depends on glob's unspecified order. Forward and back slashes are both
        treated as separators, so the key is identical on Windows and POSIX.
        """
        name = file_path.replace('\\', '/').rsplit('/', 1)[-1]
        return name.split('_')[-1].split('.')[0], name

    @staticmethod
    def _label_path(data_path, root):
        """Map an input file path to the matching label file path.

        Only the ``numpy_data`` directory directly under ``root`` is replaced by
        ``numpy_labels``. A blanket ``str.replace`` would also rewrite any other
        'numpy_data' substring in ``root`` or in the file name.
        """
        prefix = root + 'numpy_data'
        if data_path.startswith(prefix):
            return root + 'numpy_labels' + data_path[len(prefix):]
        # Fallback when glob did not preserve the prefix verbatim.
        return data_path.replace('numpy_data', 'numpy_labels')

    @staticmethod
    def _file_stem(file_path):
        """Return the file name of ``file_path`` up to its first '.npy'.

        Both '/' and back slashes are treated as separators, so Windows paths
        do not leak the directory into the name.
        """
        name = file_path.replace('\\', '/').rsplit('/', 1)[-1]
        return name.split('.npy')[0]

    def __len__(self):
        return len(self.fold_data)

    def random_select(self, num=5):
        """Return ``num`` distinct input file paths sampled from this fold.

        Raises ``ValueError`` (from ``random.sample``) if ``num`` exceeds the fold size.
        """
        return random.sample(self.fold_data, num)

    def __getitem__(self, idx):
        """Load and featurize the sample at ``idx``.

        Note: the input data was pre-processed into '.npy' files for fast loading.
        To evaluate on your own dataset, the image-loading code must be changed
        accordingly.

        Returns:
            ``(features, illuminant, img_name)``. Features and illuminant are
            ``np.float32``. The illuminant is divided by its sum, so its entries
            add up to 1. ``img_name`` is the label file name without '.npy'.
        """
        img_data = np.load(self.fold_data[idx])
        gd_data = np.load(self.fold_illu[idx])
        gd_data = gd_data / gd_data.sum()
        img_name = self._file_stem(self.fold_illu[idx])
        feature_data = feature_select(img_data)
        return feature_data.astype(np.float32), gd_data.astype(np.float32), img_name