Skip to content

Commit 19cdc1d

Browse files
committed
lstm inference
1 parent 4e31572 commit 19cdc1d

File tree

7 files changed

+251
-20
lines changed

7 files changed

+251
-20
lines changed

evals/eval.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -311,33 +311,33 @@ def eval_signal_results_from_h5(path):
311311
eval_signal_results_from_h5('eval_data.h5')
312312

313313
else:
314-
with h5py.File('../outputs/re_ep28.h5', 'r') as db:
314+
with h5py.File('../outputs/test_re_cnnlstm.h5', 'r') as db:
315315
keys = [key for key in db.keys()]
316316
print(keys)
317317

318318
refs = db['reference'][:]
319319
ref_list = np.empty(shape=(len(refs), 1))
320320
for i in range(len(refs)):
321-
ref_list[i] = mode(refs[i, :])[0]
321+
ref_list[i] = refs[i]
322322

323323
print(ref_list.shape)
324324
rates = db['rates'][:]
325325
print(rates.shape)
326326
signal = db['signal'][:]
327327

328-
with h5py.File('../outputs/re_ep93.h5') as db:
329-
rates2 = db['rates'][:]
330-
signal2 = db['signal'][:]
331-
332-
with h5py.File('../outputs/re_noncrop_ep93.h5') as db:
333-
rates3 = db['rates'][:]
328+
# with h5py.File('../outputs/re_ep93.h5') as db:
329+
# rates2 = db['rates'][:]
330+
# signal2 = db['signal'][:]
331+
#
332+
# with h5py.File('../outputs/re_noncrop_ep93.h5') as db:
333+
# rates3 = db['rates'][:]
334334
# signal2 = db['signal'][:]
335335

336-
# plt.figure()
337-
# plt.title('output of the first network')
338-
# plt.plot(signal, label='ep28')
336+
plt.figure()
337+
plt.title('output of the first network')
338+
plt.plot(signal)
339339
# plt.plot(signal2 + 4*np.std(signal2), label='ep93')
340-
# plt.show()
341-
eval_rate_results(ref_list, (rates, rates2, rates3))
340+
plt.show()
341+
eval_rate_results(ref_list, (rates, ))
342342

343343

evals/eval_rates.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import numpy as np
2+
from matplotlib import pyplot as plt
3+
import h5py
4+
from scipy.stats import pearsonr
5+
6+
7+
def eval_rate_results(ref, ests: tuple, sigs: tuple, labels:tuple):
8+
"""
9+
:param ref: reference pulse rate values
10+
:param ests: tuple of network probability estimates
11+
:param sigs: tuple of intermediate signals (output of PhysNet)
12+
"""
13+
assert len(ests) == len(sigs), 'Number of estimates must match number of signals!'
14+
n_est = len(ests)
15+
16+
# Plot signal waveform
17+
plt.figure(figsize=(12, 8))
18+
w = 1000
19+
for i in range(n_est):
20+
tmp = sigs[i][:].flatten()[:w]
21+
tmp = (tmp-np.mean(tmp))/np.std(tmp)
22+
plt.plot(tmp[:w]+(n_est-i)*5, label=labels[i])
23+
plt.xlabel('Time [h]')
24+
plt.title('Estimated signal form')
25+
plt.subplots_adjust(top=0.95, bottom=0.05)
26+
plt.legend()
27+
28+
N = len(ref)
29+
Fs = 20.
30+
n = 128
31+
T = n/Fs # length between two points in seconds
32+
t = np.linspace(0, N-1, N)*T/60./60. # time vector in hours
33+
ref = ref*60. # in BPM
34+
for i, est in enumerate(ests):
35+
est = est*60.
36+
plt.figure(figsize=(12, 6))
37+
plt.title(labels[i])
38+
plt.plot(t, ref, 'k', linewidth=2, label='reference')
39+
plt.plot(t, est[:, 0], 'r--', linewidth=1.5, label='mean estimate')
40+
plt.fill_between(t, est[:, 0] + est[:, 1], est[:, 0] - est[:, 1], color='r', alpha=0.2, label='confidence')
41+
plt.legend()
42+
plt.grid()
43+
plt.show()
44+
45+
est_list = np.empty((ests[0].shape[0], n_est), dtype=float)
46+
for i, est in enumerate(ests):
47+
est_list[:, i] = est[:, 0]*60. # use the expected value statistics
48+
49+
# remove 6.2 to 8 hours from arrays since these parts are corrupted
50+
# start_rmidx = int(6.2*60*60)
51+
# end_rmidx = int(8*60*60)
52+
# idxs2rm = [x for x in range(start_rmidx, end_rmidx)]
53+
# ref_list = np.delete(ref, idxs2rm, axis=0)
54+
# est_list = np.delete(est_list, idxs2rm, axis=0)
55+
56+
# Calculate metrics
57+
MAEs = np.mean(np.abs(np.subtract(ref, est_list)), axis=0)
58+
RMSEs = np.sqrt(np.mean(np.subtract(ref, est_list)**2, axis=0))
59+
MSEs = np.mean(np.subtract(ref, est_list)**2, axis=0)
60+
61+
rs = np.empty((n_est, 1), dtype=float)
62+
for count in range(n_est):
63+
rs[count] = pearsonr(ref.squeeze(), est_list[:, count])[0]
64+
65+
for i in range(n_est):
66+
print(f'\n({i})th statistics')
67+
print(f'MAE: {MAEs[i]}')
68+
print(f'RMSE: {RMSEs[i]}')
69+
print(f'MSE: {MSEs[i]}')
70+
print(f'Pearson r: {rs[i]}')
71+
print('-------------------------------------------------')
72+
73+
74+
if __name__ == '__main__':
75+
with h5py.File('../outputs/re_ep28.h5', 'r') as db:
76+
keys = [key for key in db.keys()]
77+
print(keys)
78+
79+
refs = db['reference'][:]
80+
ref_arr = np.empty(shape=(len(refs), 1))
81+
for i in range(len(refs)):
82+
ref_arr[i] = np.mean(refs[i])
83+
84+
print(f'Reference shape: {ref_arr.shape}')
85+
rates = db['rates'][:]
86+
print(rates.shape)
87+
88+
signal = db['signal'][:]
89+
90+
with h5py.File('../outputs/re_ep93.h5') as db:
91+
rates2 = db['rates'][:]
92+
signal2 = db['signal'][:]
93+
94+
with h5py.File('../outputs/re_noncrop_ep93.h5') as db:
95+
rates3 = db['rates'][:]
96+
signal3 = db['signal'][:]
97+
98+
eval_rate_results(ref_arr, ests=(rates, rates2, rates3), sigs=(signal, signal2, signal3),
99+
labels=('RateProbEst-crop-ep28', 'RateProbEst-crop-ep93', 'RateProbEst-full-ep93'))
100+

infer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def eval_model(models, testloader, criterion, oname):
2828

2929
with tr.no_grad():
3030
if len(models) == 1:
31-
outputs = model(*inputs).squeeze()
31+
outputs = models[0](*inputs).squeeze()
3232
# print(f'outputs.shape: {outputs.shape}')
3333

3434
if criterion is not None:
@@ -223,4 +223,4 @@ def eval_model(models, testloader, criterion, oname):
223223
# -------------------------------
224224
eval_model(models, testloader, criterion=loss_fn, oname=args.ofile_name)
225225

226-
print('Succefully finished!')
226+
print('Successfully finished!')

infer_lstm.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from src.archs import PhysNetED, RateProbLSTMCNN
2+
from src.dset import Dataset4DFromHDF5
3+
4+
import h5py
5+
import numpy as np
6+
import argparse
7+
from tqdm import tqdm
8+
9+
import torch
10+
from torch.utils.data import DataLoader
11+
tr = torch
12+
13+
14+
def eval_model(models, testloader, oname, device):
15+
total_loss = []
16+
result = []
17+
signal = []
18+
ref = []
19+
h1 = h2 = None
20+
21+
for inputs, targets in tqdm(testloader):
22+
with tr.no_grad():
23+
inputs = inputs.to(device)
24+
targets = targets.to(device)
25+
26+
# Signal extractor
27+
signals = models[0](inputs).view(-1, 1, 128)
28+
# Rate estimator
29+
rates, h1, h2 = models[1](signals, h1, h2)
30+
31+
rates = rates.view(-1, 2)
32+
targets = targets.squeeze()
33+
34+
# print(f'in inference targets.shape: {targets.shape}')
35+
# print(targets)
36+
result.extend(rates.data.cpu().numpy().tolist())
37+
signal.extend(signals.data.cpu().numpy().flatten().tolist())
38+
ref.extend(targets.data.cpu().numpy().reshape(-1, 1).tolist())
39+
40+
result = np.array(result)
41+
ref = np.array(ref)
42+
signal = np.array(signal)
43+
with h5py.File(f'outputs/{oname}.h5', 'w') as db:
44+
db.create_dataset('reference', shape=ref.shape, dtype=np.float32, data=ref)
45+
db.create_dataset('signal', shape=signal.shape, dtype=np.float32, data=signal)
46+
db.create_dataset('rates', shape=result.shape, dtype=np.float32, data=result)
47+
48+
print('Result saved!')
49+
50+
51+
if __name__ == '__main__':
52+
# train on the GPU or on the CPU, if a GPU is not available
53+
device_ = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
54+
print(device_)
55+
56+
parser = argparse.ArgumentParser()
57+
parser.add_argument('--data', type=str, help='path to benchmark .hdf5 file containing data')
58+
parser.add_argument('--interval', type=int, nargs='+',
59+
help='indices: val_start, val_end, shift_idx; if not given -> whole dataset')
60+
parser.add_argument("--weights", type=str, nargs='+', help="model weight paths")
61+
62+
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
63+
parser.add_argument("--ofile_name", type=str, help="output file name")
64+
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during generation')
65+
parser.add_argument('--crop', type=bool, default=False, help='crop baby with yolo (preprocessing step)')
66+
67+
args = parser.parse_args()
68+
start_idx = end_idx = None
69+
if args.interval:
70+
start_idx, end_idx = args.interval
71+
72+
# ---------------------------------------
73+
# Construct datasets
74+
# ---------------------------------------
75+
ref_type = 'PulseNumerical'
76+
testset = Dataset4DFromHDF5(args.data,
77+
labels=(ref_type,),
78+
device=torch.device('cpu'),
79+
start=args.interval[0], end=args.interval[1],
80+
crop=args.crop,
81+
augment=False,
82+
augment_freq=False
83+
)
84+
85+
testloader_ = DataLoader(testset,
86+
batch_size=args.batch_size,
87+
shuffle=False,
88+
num_workers=args.n_cpu,
89+
pin_memory=True)
90+
91+
# --------------------------
92+
# Load model
93+
# --------------------------
94+
models_ = [PhysNetED(), RateProbLSTMCNN()]
95+
96+
# ----------------------------------
97+
# Set up training
98+
# ---------------------------------
99+
for i in range(len(models_)):
100+
models_[i] = tr.nn.DataParallel(models_[i])
101+
models_[i].load_state_dict(tr.load(args.weights[i], map_location=device_))
102+
103+
# Use multiple GPU if there are!
104+
if torch.cuda.device_count() > 1:
105+
print("Let's use", torch.cuda.device_count(), "GPUs!")
106+
else:
107+
for i in range(len(models_)):
108+
models_[i] = models_[i].module
109+
110+
# Copy model to working device
111+
for i in range(len(models_)):
112+
models_[i] = models_[i].to(device_)
113+
114+
# -------------------------------
115+
# Evaluate model
116+
# -------------------------------
117+
eval_model(models_, testloader_, oname=args.ofile_name, device=device_)
118+
119+
print('Successfully finished!')

src/archs.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def __init__(self):
430430
)
431431

432432
self.end_part = nn.Sequential(
433-
nn.Dropout(0.5),
433+
nn.Dropout(0.3),
434434
nn.Conv1d(32, 16, kernel_size=1, stride=1, padding=0),
435435
nn.MaxPool1d(kernel_size=max_pool_kernel_size, stride=2, padding=2),
436436
)
@@ -457,11 +457,19 @@ def __init__(self):
457457
self.inception_block = InceptionBlock()
458458
self.cnn_block = CNNBlock()
459459

460-
self.lstm_layer1 = nn.LSTM(input_size=128, hidden_size=self.n_hid, num_layers=self.n_layers, dropout=0.3)
460+
self.lstm_layer1 = nn.LSTM(input_size=128, hidden_size=self.n_hid, num_layers=self.n_layers, dropout=0.2)
461461
self.lstm_layer2 = nn.LSTM(input_size=336, hidden_size=self.n_hid, num_layers=self.n_layers, dropout=0.5)
462462

463463
self.linear = nn.Linear(80, 2)
464464

465+
def init_hidden(self, bsz):
466+
"""
467+
Returns initial hidden state and hidden cell values
468+
"""
469+
weight = next(self.parameters())
470+
return (weight.new_zeros(self.n_layers, bsz, self.n_hid),
471+
weight.new_zeros(self.n_layers, bsz, self.n_hid))
472+
465473
def forward(self, x, h1=None, h2=None):
466474
# convolution stream
467475
x1 = self.inception_block(x)

train.cfg

+6-2
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,9 @@ python infer.py PhysNet RateProbEst \
2323
--weights checkpoints/rateextractor/model0_ep28.pt checkpoints/rateextractor/model1_ep28.pt \
2424
--ofile_name re_ep28.h5
2525

26-
# LSTM trainer setup
27-
DataSet -> batch_size > 1, augm=False, freq_augm=False, D=180
26+
# LSTM
27+
python train_lstm.py PhysNet CNNLSTM --loss Laplace --lr 1e-4 1e-3 \
28+
--data ../data/PUBLIC/benchmark_set/PIC191111_128x128_U8C3_fast.hdf5 --intervals 72000 180000 180000 216000 \
29+
--logger_name rateextractor --epochs 100 --epoch_start 19 \
30+
--pretrained_weights checkpoints/re_cnnlstm/model0_ep18.pt checkpoints/re_cnnlstm/model1_ep18.pt \
31+
--checkpoint_dir re_cnnlstm

train_lstm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def train_model(models, dataloaders, criterion, optimizers, schedulers, opath, n
6969
with experiment.test():
7070
experiment.log_metric("loss", epoch_loss, step=epoch+start_epoch)
7171
# Learning Rate scheduler (if epoch loss is on plato)
72-
schedulers[0].step(epoch_loss)
72+
schedulers[0].step(epoch_loss)
7373
schedulers[1].step(epoch_loss)
7474
else:
7575
train_loss_history.append(epoch_loss)

0 commit comments

Comments
 (0)