import os

import cv2
import numpy as np
import skimage.transform
import torch
from torch.utils import data

from datasets.utils import recursive_glob, Compose, RandomHorizontallyFlip, RandomRotate, Scale


class City(data.Dataset):
    """Cityscapes segmentation dataset loader.

    Data is derived from Cityscapes and can be downloaded from
    https://www.cityscapes-dataset.com/downloads/

    Many thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """
| 21 | + |
| 22 | + colors = [ # [ 0, 0, 0], |
| 23 | + [128, 64, 128], |
| 24 | + [244, 35, 232], |
| 25 | + [70, 70, 70], |
| 26 | + [102, 102, 156], |
| 27 | + [190, 153, 153], |
| 28 | + [153, 153, 153], |
| 29 | + [250, 170, 30], |
| 30 | + [220, 220, 0], |
| 31 | + [107, 142, 35], |
| 32 | + [152, 251, 152], |
| 33 | + [0, 130, 180], |
| 34 | + [220, 20, 60], |
| 35 | + [255, 0, 0], |
| 36 | + [0, 0, 142], |
| 37 | + [0, 0, 70], |
| 38 | + [0, 60, 100], |
| 39 | + [0, 80, 100], |
| 40 | + [0, 0, 230], |
| 41 | + [119, 11, 32], |
| 42 | + ] |
| 43 | + |
| 44 | + label_colours = dict(zip(range(19), colors)) |

    # The "pascal" entry is the BGR channel means used by the PSPNet and
    # ICNet pre-trained models; "cityscapes" applies no mean subtraction.
    mean_rgb = {
        "pascal": [103.939, 116.779, 123.68],
        "cityscapes": [0.0, 0.0, 0.0],
    }

    def __init__(
        self,
        root,
        split="train",
        is_transform=False,
        img_size=(256, 512),
        augmentations=None,
        img_norm=True,
        version="cityscapes",
        test_mode=False,
    ):
        """__init__

        :param root: path to the Cityscapes root directory
        :param split: one of "train", "val", or "test"
        :param is_transform: if True, apply `transform` to each sample
        :param img_size: output (height, width)
        :param augmentations: optional augmentation callable applied to (img, lbl)
        :param img_norm: if True, scale image values to [0, 1]
        :param version: key into `mean_rgb` selecting the mean to subtract
        """
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.n_classes = 19
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.mean = np.array(self.mean_rgb[version])
        self.files = {}

        self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
        self.annotations_base = os.path.join(self.root, "gtFine", self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png")

        # Raw Cityscapes label IDs outside the 19 training classes
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label IDs of the 19 training classes, in train-ID order
        self.valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 31, 32, 33,
        ]
        self.class_names = [
            "unlabelled",
            "road", "sidewalk", "building", "wall", "fence", "pole",
            "traffic_light", "traffic_sign", "vegetation", "terrain", "sky",
            "person", "rider", "car", "truck", "bus", "train",
            "motorcycle", "bicycle",
        ]

        self.ignore_index = 250
        # Map raw label IDs to train IDs 0..18
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))
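
    # Typical usage (a sketch; the dataset root path is hypothetical):
    #
    #   dst = City("/path/to/cityscapes", split="train", is_transform=True)
    #   sample = dst[0]
    #   img, mask = sample["image"], sample["mask"]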

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image-label pair, then augment/transform it if configured.

        :param index: index of the sample within the current split
        """
        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(
            self.annotations_base,
            img_path.split(os.sep)[-2],
            os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
        )

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.float32(img)

        # Single-channel uint8 ndarray of shape (H, W) holding raw label IDs
        lbl = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE)
        lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return {"image": img, "mask": lbl}

    def transform(self, img, lbl):
        """Resize, normalise, and convert an (img, lbl) pair to tensors.

        :param img: float32 RGB image of shape (H, W, 3)
        :param lbl: encoded label map of shape (H, W)
        """
        img = skimage.transform.resize(
            img,
            (self.img_size[0], self.img_size[1]),
            mode="edge",
            anti_aliasing=False,
            anti_aliasing_sigma=None,
            preserve_range=True,
            order=0,
        )
        img = img[:, :, ::-1]  # RGB -> BGR (the pascal mean values are in BGR order)
        img = img.astype(np.float64)
        img -= self.mean
        if self.img_norm:
            # preserve_range=True keeps values in [0, 255], so divide by
            # 255.0 to bring them into [0, 1]
            img = img / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        # order=0 is nearest-neighbour, so resizing cannot invent label values
        lbl = skimage.transform.resize(
            lbl,
            (self.img_size[0], self.img_size[1]),
            mode="edge",
            anti_aliasing=False,
            anti_aliasing_sigma=None,
            preserve_range=True,
            order=0,
        )
        lbl = lbl.astype(int)

        if not np.all(classes == np.unique(lbl)):
            print("WARN: resizing labels yielded fewer classes")

        if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
            print("after det", classes, np.unique(lbl))
            raise ValueError("Segmentation map contained invalid class values")

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl
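
    # For example, with the default img_size=(256, 512), `transform` maps a
    # (1024, 2048, 3) float32 image and its (1024, 2048) label map to a
    # float32 tensor of shape (3, 256, 512) and an int64 tensor of shape
    # (256, 512).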

    def decode_segmap(self, temp):
        """Map a train-ID label map of shape (H, W) to an RGB image in [0, 1]."""
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for l in range(0, self.n_classes):
            r[temp == l] = self.label_colours[l][0]
            g[temp == l] = self.label_colours[l][1]
            b[temp == l] = self.label_colours[l][2]

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb

    def encode_segmap(self, mask):
        """Map raw Cityscapes label IDs to train IDs; void classes become ignore_index."""
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask
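
    # A quick illustration of the mapping (input values chosen for this
    # example): raw ID 7 ("road") becomes train ID 0, raw ID 26 ("car")
    # becomes train ID 13, and a void ID such as 3 becomes ignore_index (250):
    #
    #   >>> loader.encode_segmap(np.array([[7, 26], [3, 33]], dtype=np.uint8))
    #   array([[  0,  13],
    #          [250,  18]], dtype=uint8)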


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)])

    local_path = "/datasets01/cityscapes/112817/"
    dst = City(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    for i, data_samples in enumerate(trainloader):
        imgs, labels = data_samples["image"], data_samples["mask"]
        imgs = imgs.numpy()[:, ::-1, :, :]  # BGR -> RGB for display
        imgs = np.transpose(imgs, [0, 2, 3, 1])  # NCHW -> NHWC
        f, axarr = plt.subplots(bs, 2)
        for j in range(bs):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
        plt.show()
        a = input()
        if a == "ex":
            break
        else:
            plt.close()