Skip to content

Commit

Permalink
Fix comet_ml callback and process_prediction (#170)
Browse files Browse the repository at this point in the history
* Fix comet_ml callback & plot minimum in loss curve

* Add fix for applying symmetry

* Add towncrier message

* Add fix for uncertainty training

* Remove hardcoded numbers in sampling

* Revert sampling default values

* Extend towncrier

* Delete print

* Fix function for uncertainty training
  • Loading branch information
FeGeyer authored May 8, 2024
1 parent b70fb97 commit 3e7e0f4
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 25 deletions.
3 changes: 3 additions & 0 deletions docs/changes/170.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- fix `comet_ml` callback
- update `process_prediction` with better if statement
- change hardcoded values for sampling
4 changes: 2 additions & 2 deletions radionets/dl_framework/architectures/res_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import nn

from radionets.dl_framework.model import GeneralELU, SRBlock
from radionets.dl_framework.model import GeneralRelu, SRBlock


class SRResNet(nn.Module):
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(self):

self.hardtanh = nn.Hardtanh(-pi, pi)
self.relu = nn.ReLU()
self.elu = GeneralELU(add=+(1 + 1e-7))
self.elu = GeneralRelu(sub=-1e-10)

def forward(self, x):
s = x.shape[-1]
Expand Down
62 changes: 46 additions & 16 deletions radionets/dl_framework/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
from radionets.dl_framework.utils import _maybe_item, get_ifft_torch
from radionets.evaluation.plotting import create_OrBu
from radionets.evaluation.utils import (
apply_normalization,
apply_symmetry,
check_vmin_vmax,
eval_model,
get_ifft,
get_images,
load_data,
load_pretrained_model,
make_axes_nice,
rescale_normalization,
)

OrBu = create_OrBu()
Expand All @@ -29,9 +32,7 @@ def __init__(self, name, test_data, plot_n_epochs, amp_phase, scale):
self.experiment = Experiment(project_name=name)
self.data_path = test_data
self.plot_epoch = plot_n_epochs
self.test_ds = load_data(
self.data_path, mode="test", fourier=True, source_list=False
)
self.test_ds = load_data(self.data_path, mode="test", fourier=True)
self.amp_phase = amp_phase
self.scale = scale
self.uncertainty = False
Expand All @@ -51,39 +52,61 @@ def after_validate(self):
)

def plot_test_pred(self):
img_test, img_true = get_images(self.test_ds, 1, rand=False)
img_test, img_true, _ = get_images(self.test_ds, 1, rand=False)
img_test = img_test.unsqueeze(0)
img_true = img_true.unsqueeze(0)
model = self.model
if self.learn.normalize.mode == "all":
norm_dict = {"all": 0}
img_test, norm_dict = apply_normalization(img_test, norm_dict)

with self.experiment.test():
with torch.no_grad():
pred = eval_model(img_test, model)
pred = rescale_normalization(pred, norm_dict)
if pred.shape[1] == 4:
self.uncertainty = True
pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1)
images = {"pred": pred, "truth": img_true}
images = apply_symmetry(images)
pred = images["pred"]
img_true = images["truth"]

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 10))
lim_phase = check_vmin_vmax(img_true[1])
lim_phase = check_vmin_vmax(img_true[0, 1])
im1 = ax1.imshow(pred[0, 0], cmap="inferno")
if self.uncertainty:
im2 = ax2.imshow(pred[0, 2], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
else:
im2 = ax2.imshow(pred[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
im3 = ax3.imshow(img_true[0], cmap="inferno")
im4 = ax4.imshow(img_true[1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
make_axes_nice(fig, ax1, im1, "Amplitude")
make_axes_nice(fig, ax2, im2, "Phase", phase=True)
make_axes_nice(fig, ax3, im3, "Org. Amplitude")
make_axes_nice(fig, ax4, im4, "Org. Phase", phase=True)
im2 = ax2.imshow(pred[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
im3 = ax3.imshow(img_true[0, 0], cmap="inferno")
im4 = ax4.imshow(img_true[0, 1], cmap=OrBu, vmin=-lim_phase, vmax=lim_phase)
make_axes_nice(fig, ax1, im1, "Real")
make_axes_nice(fig, ax2, im2, "Imaginary")
make_axes_nice(fig, ax3, im3, "Org. Real")
make_axes_nice(fig, ax4, im4, "Org. Imaginary")
fig.tight_layout(pad=0.1)
self.experiment.log_figure(
figure=fig, figure_name=f"{self.epoch + 1}_pred_epoch"
)
plt.close("all")

def plot_test_fft(self):
img_test, img_true = get_images(self.test_ds, 1, rand=False)
img_test, img_true, _ = get_images(self.test_ds, 1, rand=False)
img_test = img_test.unsqueeze(0)
img_true = img_true.unsqueeze(0)
model = self.model
if self.learn.normalize.mode == "all":
norm_dict = {"all": 0}
img_test, norm_dict = apply_normalization(img_test, norm_dict)

with self.experiment.test():
with torch.no_grad():
pred = eval_model(img_test, model)
pred = rescale_normalization(pred, norm_dict)
if self.uncertainty:
pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1)
images = {"pred": pred, "truth": img_true}
images = apply_symmetry(images)
pred = images["pred"]
img_true = images["truth"]

ifft_pred = get_ifft_torch(
pred,
Expand Down Expand Up @@ -150,8 +173,15 @@ def after_batch(self):
self.lrs.append(self.opt.hypers[-1]["lr"])

def plot_loss(self):
min_epoch = np.argmin(self.loss_valid)
plt.plot(self.loss_train, label="Training loss")
plt.plot(self.loss_valid, label="Validation loss")
plt.axvline(
min_epoch,
color="black",
linestyle="dashed",
label=f"Minimum at Epoch {min_epoch}",
)
plt.xlabel(r"Number of Epochs")
plt.ylabel(r"Loss")
plt.legend()
Expand Down
2 changes: 1 addition & 1 deletion radionets/dl_framework/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, leak=None, sub=None, maxv=None):
def forward(self, x):
x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
if self.sub is not None:
x.sub_(self.sub)
x = x - self.sub
if self.maxv is not None:
x.clamp_max_(self.maxv)
return x
Expand Down
7 changes: 6 additions & 1 deletion radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,12 @@ def save_sampled(conf):
result = sample_images(img["pred"], img["unc"], 100, conf)

# pad true image
output = F.pad(input=img["true"], pad=(0, 0, 0, 63), mode="constant", value=0)
output = F.pad(
input=img["true"],
pad=(0, 0, 0, img_size // 2 - 1),
mode="constant",
value=0,
)
img["true"] = symmetry(output, None)
ifft_truth = get_ifft(img["true"], amp_phase=conf["amp_phase"])

Expand Down
13 changes: 8 additions & 5 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,15 +580,15 @@ def sample_images(mean, std, num_samples, conf):
sig=std_amp,
mode=mode[0],
num_samples=num_samples,
).reshape(num_img * num_samples, 65, 128)
).reshape(num_img * num_samples, mean_amp.shape[-2], mean_amp.shape[-1])

# phase
sampled_gauss_phase = trunc_rvs(
mu=mean_phase,
sig=std_phase,
mode=mode[1],
num_samples=num_samples,
).reshape(num_img * num_samples, 65, 128)
).reshape(num_img * num_samples, mean_phase.shape[-2], mean_phase.shape[-1])

# masks
if conf["amp_phase"]:
Expand All @@ -604,13 +604,16 @@ def sample_images(mean, std, num_samples, conf):

# pad resulting images and utilize symmetry
sampled_gauss = F.pad(
input=torch.tensor(sampled_gauss), pad=(0, 0, 0, 63), mode="constant", value=0
input=torch.tensor(sampled_gauss),
pad=(0, 0, 0, mean_amp.shape[-2] - 2),
mode="constant",
value=0,
)
sampled_gauss_symmetry = symmetry(sampled_gauss, None)

fft_sampled_symmetry = get_ifft(
sampled_gauss_symmetry, amp_phase=conf["amp_phase"], scale=False
).reshape(num_img, num_samples, 128, 128)
).reshape(num_img, num_samples, mean_amp.shape[-1], mean_amp.shape[-1])

results = {
"mean": fft_sampled_symmetry.mean(axis=1),
Expand Down Expand Up @@ -836,7 +839,7 @@ def process_prediction(conf, img_test, img_true, norm_dict, model, model_2):
pred = torch.cat((pred, pred_2), dim=1)

# apply symmetry
if pred.shape[-1] == 128:
if pred.shape[-2] < pred.shape[-1]:
img_dict = {"truth": img_true, "pred": pred}
img_dict = apply_symmetry(img_dict)
img_true = img_dict["truth"]
Expand Down

0 comments on commit 3e7e0f4

Please sign in to comment.