Skip to content

Commit 446965e

Browse files
Add files via upload
1 parent 7c4385c commit 446965e

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

load_ds_histogram.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import torch
2+
import torchvision
3+
import torchvision.datasets as datasets
4+
from torchvision.transforms import v2
5+
from torch.utils.data import DataLoader
6+
from torchvision.datasets.folder import default_loader
7+
import cv2
8+
import numpy as np
9+
import os
10+
11+
NUM_WORKERS = os.cpu_count()
12+
13+
def create_dataloaders(train_path: str,
14+
test_path: str,
15+
batch_size: int,
16+
pre_proc_type: str,
17+
num_workers: int = NUM_WORKERS):
18+
19+
# Importing the datasets with imageFolder
20+
train_ds = HistogramDataset(train_path, pre_proc_type)
21+
test_ds = HistogramDataset(test_path, pre_proc_type)
22+
23+
# Creating the dataloaders
24+
train_dataloader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=True, drop_last=False)
25+
test_dataloader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=True, drop_last=False)
26+
27+
classes = train_ds.classes
28+
29+
return train_dataloader, test_dataloader, classes
30+
31+
32+
class HistogramDataset(torchvision.datasets.ImageFolder):
33+
def __init__(self, root, preproc_type, loader=default_loader, is_valid_file=None):
34+
super(HistogramDataset, self).__init__(root=root, loader=loader, is_valid_file=is_valid_file)
35+
self.pre_proc_type = preproc_type
36+
37+
def __getitem__(self, index):
38+
image_path, target = self.samples[index]
39+
im = cv2.imread(image_path)
40+
41+
im_nonoise = cv2.GaussianBlur(im, (3, 3), 1)
42+
if(self.pre_proc_type == 'lab' or self.pre_proc_type=='rgb'):
43+
if(self.pre_proc_type == 'lab'):
44+
prep_image = (im_nonoise * 1. / 255).astype(np.float32)
45+
im_lab = cv2.cvtColor(prep_image, cv2.COLOR_BGR2LAB)
46+
hist = calc_hists(im_lab, self.pre_proc_type)
47+
48+
# Setting up a matrix
49+
hist = np.stack([h for h in hist], axis=-1)
50+
# hist = np.stack([h for h in hist], axis=-1)
51+
hist = np.squeeze(hist)
52+
53+
# Normalizing the vector with L2 normalization
54+
norm = np.linalg.norm(hist)
55+
norm_hist = hist / norm
56+
# you need to convert img from np.array to torch.tensor
57+
# this has to be done CAREFULLY!
58+
sample = torchvision.transforms.ToTensor()(norm_hist)
59+
return sample, target
60+
61+
62+
# Define a function to compute the histogram of the image (channel by channel)
63+
def calc_hists(img: np.ndarray, hist_type) -> list:
64+
"""
65+
Calculates the histogram of the image (channel by channel).
66+
67+
Args:
68+
img (numpy.ndarray): image to calculate the histogram
69+
70+
Returns:
71+
list: list of histograms
72+
"""
73+
74+
assert img.ndim == 3, "The image must have 3 dimensions: (Height,Width,Channels)"
75+
76+
ch_1 = img[..., 0]
77+
ch_2 = img[..., 1]
78+
ch_3 = img[..., 2]
79+
80+
# Color image
81+
if hist_type == 'rgb':
82+
# Get the 3 channels
83+
# Compute the histogram for each channel. Please, bear in mind that in the "Range" parameter, the upper bound is exclusive. So, for considering values in the range [0,255] we must pass [0,256]. https://docs.opencv.org/3.4/d8/dbc/tutorial_histogram_calculation.html
84+
blue_hist = cv2.calcHist([ch_1], [0], None, [16], [0, 256])
85+
red_hist = cv2.calcHist([ch_2], [0], None, [16], [0, 256])
86+
green_hist = cv2.calcHist([ch_3], [0], None, [16], [0, 256])
87+
88+
return [blue_hist, green_hist, red_hist]
89+
# Greyscale image
90+
elif hist_type == 'lab':
91+
92+
L_hist = cv2.calcHist([ch_1], [0], None, [16], [0, 100])
93+
a_hist = cv2.calcHist([ch_2], [0], None, [16], [-128, 128])
94+
b_hist = cv2.calcHist([ch_3], [0], None, [16], [-128, 128])
95+
96+
return [L_hist, a_hist, b_hist]
97+
else:
98+
raise Exception("The image must have either 1 (greyscale image) or 3 (color image) channels")
99+
100+

model_mlp.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from torch import nn
3+
4+
class MLP(torch.nn.Module):
5+
def __init__(self, in_feature, hidden_size, num_classes):
6+
super().__init__()
7+
self.layer_1 = nn.Linear(in_features=in_feature, out_features=hidden_size)
8+
self.layer_2 = nn.Linear(in_features=hidden_size, out_features=num_classes)
9+
10+
def forward(self, x):
11+
x = nn.Flatten()(x)
12+
x = self.layer_1(x)
13+
x = nn.functional.relu(x)
14+
x = self.layer_2(x)
15+
return x

trainer_hist.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch, torchvision
2+
from torch import nn
3+
from torchvision import transforms
4+
import matplotlib.pyplot as plt
5+
import evaluate
6+
from torch.utils.tensorboard import SummaryWriter
7+
import engine
8+
import load_ds_histogram, model_mlp
9+
import os
10+
import re
11+
from pathlib import Path
12+
import warnings
13+
import argparse
14+
import datetime
15+
16+
warnings.filterwarnings("ignore")
17+
18+
19+
parser = argparse.ArgumentParser()
20+
# These arguments will be set appropriately by ReCodEx, even if you change them.
21+
parser.add_argument("--batch_size", default=16, type=int, help="Batch size.")
22+
parser.add_argument('--epochs', default=30, type=int, help="Epochs.")
23+
parser.add_argument('--learning_rate', default=0.1, type=float, help="Learning rate.")
24+
parser.add_argument('--label_smoothing', default=0.1, type=float, help='Label smoothing.')
25+
parser.add_argument('--preproc_type', default='lab', type=str, choices=['lab', 'rgb'], help='Type of preprocessing')
26+
parser.add_argument('--hidden_size', default=32, type=int, help='Number of hidden neurons in the MLP')
27+
parser.add_argument('--only_inference', default=True, type=bool, help='Number of hidden neurons in the MLP')
28+
29+
def main(args: argparse.Namespace):
30+
args.logdir = os.path.join("logs", "{}-{}-{}".format(
31+
os.path.basename(globals().get("__file__", "notebook")),
32+
datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
33+
",".join(("{}={}".format(re.sub("(.)[^_]*_?", r"\1", k), v) for k, v in sorted(vars(args).items())))
34+
))
35+
36+
37+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
38+
39+
data_path = Path('syndrone')
40+
train_path = data_path / 'train'
41+
test_path = data_path / 'test'
42+
43+
train_dataloader, test_dataloader, classes = load_ds_histogram.create_dataloaders(train_path, test_path, args.batch_size, args.preproc_type)
44+
45+
mlp = model_mlp.MLP(16 * 3, args.hidden_size, len(classes))
46+
47+
loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
48+
optimizer = torch.optim.SGD(mlp.parameters(), lr=args.learning_rate)
49+
50+
writer = SummaryWriter(log_dir=args.logdir)
51+
52+
# Generate the metrics
53+
clf_metrics = {'precision': evaluate.load("precision"),
54+
'recall': evaluate.load("recall"),
55+
'f1': evaluate.load("f1"),
56+
'accuracy': evaluate.load("accuracy")}
57+
58+
59+
if (not args.only_inference):
60+
engine.train(mlp, train_dataloader, test_dataloader, optimizer, loss_fn, args.epochs, device, clf_metrics, 0, writer, 'mlp',args.logdir)
61+
62+
mlp.load_state_dict(torch.load(str(f'{args.logdir}/model.pth')))
63+
64+
# Now test on UAVid
65+
data_path = Path('UAVid')
66+
train_path = data_path / 'train'
67+
test_path = data_path / 'test'
68+
69+
train_dataloader, test_dataloader, classes = load_ds_histogram.create_dataloaders(train_path, test_path, args.batch_size, args.preproc_type)
70+
71+
results = engine.test_step(mlp, test_dataloader, loss_fn, device, clf_metrics)
72+
73+
for k, v in results.items():
74+
writer.add_scalar(f'test/{k}', v, 1)
75+
76+
77+
if __name__ == "__main__":
78+
args = parser.parse_args([] if "__file__" not in globals() else None)
79+
main(args)

0 commit comments

Comments
 (0)