Skip to content

Commit e847ec0

Browse files
authored
Merge pull request #1 from dansola/light_weight
Light weight
2 parents 283d984 + e76410f commit e847ec0

19 files changed

+759
-21
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ These proposals are incorporated into a Dataset class where a random index dicti
2626
Then PGA can be used on an image and its corresponding proposal dictionaries.
2727

2828
```python
29-
from models.basic_pga.basic_pga_parts import BlockPGA
29+
from src.models.basic_pga.basic_pga_parts import BlockPGA
3030
import torch
3131

3232
img = torch.rand(1,3,500,500) # test image

notebooks/__init__.py

Whitespace-only changes.

requirements.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ kiwisolver==1.3.1
44
matplotlib==3.3.3
55
olefile==0.46
66
opencv-python==4.5.1.48
7-
Pillow @ file:///tmp/build/80754af9/pillow_1609786786540/work
7+
Pillow==8.1.0
88
pyparsing==2.4.7
99
python-dateutil==2.8.1
10-
six @ file:///tmp/build/80754af9/six_1605205327372/work
10+
six==1.15.0
1111
torch==1.7.1
1212
torchvision==0.8.2
13-
typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1598376058250/work
13+
typing-extensions
1414
numpy~=1.19.2
1515
pillow~=8.1.0
1616
wandb~=0.10.15

setup.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env python
2+
import os
3+
import sys
4+
5+
from setuptools import find_packages, setup
6+
from setuptools.command.install import install
7+
8+
from src import VERSION
9+
10+
11+
class VerifyVersionCommand(install):
    """Setuptools command that checks the release tag against ``VERSION``.

    The tag is read from the ``COMMIT_TAG`` environment variable.  When it
    differs from the package version the process exits with an explanatory
    message; otherwise the command does nothing.
    """

    description = "verify that the package git tag matches our version"

    def run(self):
        """Exit with an error message unless ``COMMIT_TAG`` equals ``VERSION``."""
        commit_tag = os.getenv("COMMIT_TAG")
        if commit_tag == VERSION:
            return

        message = "Git tag: {0} does not match the version of this app: {1}".format(
            commit_tag, VERSION
        )
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(message)
24+
25+
26+
# Resolve requirements.txt next to this file so that the build does not depend
# on the directory the command happens to be launched from.
_HERE = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(_HERE, "requirements.txt"), encoding="utf-8") as f:
    # Blank lines and "#" comments are not valid requirement specifiers, so
    # drop them before handing the list to setuptools.
    DEPENDENCIES = [
        line.strip()
        for line in f
        if line.strip() and not line.strip().startswith("#")
    ]

setup(
    name="src",
    packages=find_packages(),
    version=VERSION,
    description="Light weight semantic segmentation.",
    author="Daniel Sola",
    license="MIT",
    install_requires=DEPENDENCIES,
    python_requires=">=3.8",
    url="https://github.com/dansola/PGA-Net",
    # "python setup.py verify" checks COMMIT_TAG against VERSION.
    cmdclass={"verify": VerifyVersionCommand},
)

src/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Package version string; setup.py imports it for the distribution version and
# for the tag check against the COMMIT_TAG environment variable.
VERSION = "0.0.1"

src/datasets/city.py

+268
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
import cv2
5+
from PIL import Image
6+
import skimage.transform
7+
8+
from torch.utils import data
9+
10+
from datasets.utils import recursive_glob, Compose, RandomHorizontallyFlip, RandomRotate, Scale
11+
12+
13+
class City(data.Dataset):
    """Cityscapes semantic-segmentation dataset (fine annotations).

    Images are read from ``<root>/leftImg8bit/<split>`` and label maps from
    ``<root>/gtFine/<split>``.  Raw label ids listed in ``valid_classes`` are
    remapped to 19 contiguous training ids; ids listed in ``void_classes``
    become ``ignore_index``.

    Each item is a dict with the keys ``'image'`` and ``'mask'``.

    https://www.cityscapes-dataset.com
    Data is derived from CityScapes, and can be downloaded from here:
    https://www.cityscapes-dataset.com/downloads/
    Many Thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """

    # One RGB colour per training id, used by decode_segmap.
    colors = [  # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]

    label_colours = dict(zip(range(19), colors))

    mean_rgb = {
        "pascal": [103.939, 116.779, 123.68],
        "cityscapes": [0.0, 0.0, 0.0],
    }  # pascal mean for PSPNet and ICNet pre-trained model

    def __init__(
        self,
        root,
        split="train",
        is_transform=False,
        img_size=(256, 512),
        augmentations=None,
        img_norm=True,
        version="cityscapes",
        test_mode=False,
    ):
        """Index the image files of one split.

        :param root: dataset root containing ``leftImg8bit`` and ``gtFine``
        :param split: sub-directory of both trees to load (e.g. ``"train"``)
        :param is_transform: when true, ``transform`` is applied to each item
        :param img_size: ``(height, width)`` tuple, or a single value used
            for both dimensions
        :param augmentations: optional callable ``(img, lbl) -> (img, lbl)``
        :param img_norm: when true, ``transform`` divides the image by 255
        :param version: key into ``mean_rgb`` selecting the mean to subtract
        :param test_mode: accepted for interface compatibility; not used here
        :raises FileNotFoundError: if no ``.png`` image is found for the split
        """
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.n_classes = 19
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.mean = np.array(self.mean_rgb[version])
        self.files = {}

        self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
        self.annotations_base = os.path.join(self.root, "gtFine", self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png")

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
            23, 24, 25, 26, 27, 28, 31, 32, 33,
        ]
        # NOTE: this list has 20 entries because it starts with "unlabelled";
        # it is therefore offset by one relative to the training ids.
        self.class_names = [
            "unlabelled",
            "road",
            "sidewalk",
            "building",
            "wall",
            "fence",
            "pole",
            "traffic_light",
            "traffic_sign",
            "vegetation",
            "terrain",
            "sky",
            "person",
            "rider",
            "car",
            "truck",
            "bus",
            "train",
            "motorcycle",
            "bicycle",
        ]

        self.ignore_index = 250
        self.class_map = dict(zip(self.valid_classes, range(19)))

        if not self.files[split]:
            # FileNotFoundError is still an Exception, so existing handlers
            # that catch Exception keep working.
            raise FileNotFoundError(
                "No files for split=[%s] found in %s" % (split, self.images_base)
            )

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of images found for the configured split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image/label pair.

        :param index: position in the file list of the configured split
        :returns: dict with ``'image'`` and ``'mask'``
        :raises FileNotFoundError: if the image or its label map cannot be read
        """
        img_path = self.files[self.split][index].rstrip()
        # The label lives in the same city sub-directory; the trailing 15
        # characters of the image name are replaced by the label suffix.
        lbl_path = os.path.join(
            self.annotations_base,
            img_path.split(os.sep)[-2],
            os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
        )

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: %s" % img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.float32(img)

        lbl = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE)  # single channel, H x W
        if lbl is None:
            raise FileNotFoundError("Could not read label map: %s" % lbl_path)
        lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return {
            'image': img,
            'mask': lbl
        }

    def transform(self, img, lbl):
        """Resize, normalise and convert an image/label pair to tensors.

        :param img: H x W x 3 RGB array
        :param lbl: H x W array of encoded label ids
        :returns: ``(img, lbl)`` as a float tensor (3 x H x W, BGR channel
            order) and a long tensor (H x W)
        :raises ValueError: if the resized label map contains a value that is
            neither ``ignore_index`` nor below ``n_classes``
        """
        target_shape = (self.img_size[0], self.img_size[1])

        img = skimage.transform.resize(img, target_shape,
                                       mode='edge',
                                       anti_aliasing=False,
                                       anti_aliasing_sigma=None,
                                       preserve_range=True,
                                       order=0)
        img = img[:, :, ::-1]  # RGB -> BGR
        img = img.astype(np.float64)
        img -= self.mean
        if self.img_norm:
            # preserve_range=True keeps the original 0-255 scale, so divide
            # by 255.0 to normalise.
            img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        # order=0 (nearest neighbour) so no new label values are interpolated.
        lbl = skimage.transform.resize(lbl, target_shape,
                                       mode='edge',
                                       anti_aliasing=False,
                                       anti_aliasing_sigma=None,
                                       preserve_range=True,
                                       order=0)
        lbl = lbl.astype(int)

        # array_equal copes with differing lengths, which is exactly the case
        # this warning is about; an element-wise "==" would not.
        if not np.array_equal(classes, np.unique(lbl)):
            print("WARN: resizing labels yielded fewer classes")

        if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
            print("after det", classes, np.unique(lbl))
            raise ValueError("Segmentation map contained invalid class values")

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl

    def decode_segmap(self, temp):
        """Colourise a 2-D map of training ids.

        :param temp: H x W array of training ids
        :returns: H x W x 3 float array with channels divided by 255; values
            outside ``range(n_classes)`` keep their numeric value in all
            three channels before the division
        """
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for class_id in range(self.n_classes):
            selected = temp == class_id
            red, green, blue = self.label_colours[class_id]
            r[selected] = red
            g[selected] = green
            b[selected] = blue

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb

    def encode_segmap(self, mask):
        """Remap raw label ids to training ids, in place.

        Ids in ``void_classes`` become ``ignore_index`` and ids in
        ``valid_classes`` become their ``class_map`` value.

        :param mask: integer array of raw label ids; modified in place
        :returns: the same ``mask`` object
        """
        # Compare against a widened snapshot of the input so that (a) the
        # negative sentinel in void_classes is compared as a signed value and
        # (b) a pixel that was already remapped is never remapped again.
        raw = mask.astype(np.int64)
        for void_id in self.void_classes:
            mask[raw == void_id] = self.ignore_index
        for valid_id in self.valid_classes:
            mask[raw == valid_id] = self.class_map[valid_id]
        return mask
241+
242+
243+
if __name__ == "__main__":
    # Interactive smoke test: show batches of images next to their decoded
    # label maps.  Type "ex" at the prompt to stop, anything else to continue.
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)])

    local_path = "/datasets01/cityscapes/112817/"
    # The dataset class in this module is City.
    dst = City(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    for data_samples in trainloader:
        # __getitem__ returns a dict, so the collated batch is a dict too.
        imgs = data_samples["image"]
        labels = data_samples["mask"]

        imgs = imgs.numpy()[:, ::-1, :, :]  # reverse the channel axis
        imgs = np.transpose(imgs, [0, 2, 3, 1])  # NCHW -> NHWC for imshow
        f, axarr = plt.subplots(bs, 2)
        # The final batch may hold fewer than bs samples.
        for j in range(imgs.shape[0]):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
        plt.show()
        a = input()
        if a == "ex":
            break
        else:
            plt.close()

0 commit comments

Comments
 (0)