Skip to content

Commit

Permalink
Uncertainty fixes and altered architecture (#171)
Browse files Browse the repository at this point in the history
* Add query for `evaluate_msssim_sampled`

* Fix `evaluate_msssim_sampled`

* Fix binning for histograms

* Add no_grad architecture

* Towncrier

* Add bins as function argument

* Delete unnecessary conversion to ms ssim
  • Loading branch information
FeGeyer authored Aug 16, 2024
1 parent 3e7e0f4 commit 799a1aa
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 15 deletions.
2 changes: 2 additions & 0 deletions docs/changes/171.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- fix `evaluate_msssim_sampled`
- fix call of `evaluate_msssim_sampled`
1 change: 1 addition & 0 deletions docs/changes/171.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- set number of bins for histogram plotting
63 changes: 63 additions & 0 deletions radionets/dl_framework/architectures/res_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,66 @@ def forward(self, x):
x4 = self.elu(x4)

return torch.cat([x0, x3, x1, x4], dim=1)


class SRResNet_16_unc_no_grad(nn.Module):
def __init__(self):
super().__init__()

self.preBlock = nn.Sequential(
nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2), nn.PReLU()
)

# ResBlock 16
self.blocks = nn.Sequential(
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
SRBlock(64, 64),
)

self.postBlock = nn.Sequential(
nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(64),
)

self.final = nn.Sequential(nn.Conv2d(64, 4, 9, stride=1, padding=4, groups=2))

self.hardtanh = nn.Hardtanh(-pi, pi)
self.relu = nn.ReLU()
self.elu = GeneralRelu(sub=-1e-10)

def forward(self, x):
s = x.shape[-1]

x = self.preBlock(x)

x = x + self.postBlock(self.blocks(x))

x = self.final(x)

x0 = x[:, 0].reshape(-1, 1, s // 2 + 1, s)
# x0 = self.relu(x0)
x1 = x[:, 1].reshape(-1, 1, s // 2 + 1, s)
# x1 = self.hardtanh(x1)
x3 = x[:, 2].reshape(-1, 1, s // 2 + 1, s)
with torch.no_grad():
x3 = self.elu(x3)

x4 = x[:, 3].reshape(-1, 1, s // 2 + 1, s)
with torch.no_grad():
x4 = self.elu(x4)

return torch.cat([x0, x3, x1, x4], dim=1)
15 changes: 4 additions & 11 deletions radionets/evaluation/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,11 +628,10 @@ def plot_box(ax, num_boxes, corners):
)


def histogram_ms_ssim(msssim, out_path, plot_format="png"):
def histogram_ms_ssim(msssim, out_path, bins=30, plot_format="png"):
mean = np.mean(msssim)
std = np.std(msssim, ddof=1)
fig, (ax1) = plt.subplots(1, figsize=(6, 4))
bins = np.arange(msssim.min(), 1 + 0.01, 0.01)
ax1.hist(
msssim,
bins=bins,
Expand Down Expand Up @@ -664,12 +663,10 @@ def histogram_ms_ssim(msssim, out_path, plot_format="png"):
plt.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150)


def histogram_sum_intensity(ratios_sum, out_path, plot_format="png"):
def histogram_sum_intensity(ratios_sum, out_path, bins=30, plot_format="png"):
fig, (ax1) = plt.subplots(1, figsize=(6, 4))
mean = np.mean(ratios_sum)
std = np.std(ratios_sum, ddof=1)
bins = np.arange(0.05, ratios_sum.max() + 0.05, 0.1)
bins = np.insert(bins, 0, 0)
ax1.hist(
ratios_sum,
bins=bins,
Expand Down Expand Up @@ -703,12 +700,10 @@ def histogram_sum_intensity(ratios_sum, out_path, plot_format="png"):
plt.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150)


def histogram_peak_intensity(ratios_peak, out_path, plot_format="png"):
def histogram_peak_intensity(ratios_peak, out_path, bins=30, plot_format="png"):
fig, (ax1) = plt.subplots(1, figsize=(6, 4))
mean = np.mean(ratios_peak)
std = np.std(ratios_peak, ddof=1)
bins = np.arange(0.05, ratios_peak.max() + 0.05, 0.1)
bins = np.insert(bins, 0, 0)
ax1.hist(
ratios_peak,
bins=bins,
Expand Down Expand Up @@ -764,12 +759,10 @@ def histogram_mean_diff(vals, out_path, plot_format="png"):
plt.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150)


def histogram_area(vals, out_path, plot_format="png"):
def histogram_area(vals, out_path, bins=30, plot_format="png"):
vals = vals.numpy()
mean = np.mean(vals)
std = np.std(vals, ddof=1)
bins = np.arange(0.05, np.round(vals.max()) + 0.05, 0.1)
bins = np.insert(bins, 0, 0)
fig, (ax1) = plt.subplots(1, figsize=(6, 4))
ax1.hist(
vals, bins=bins, color="darkorange", linewidth=3, histtype="step", alpha=0.75
Expand Down
7 changes: 6 additions & 1 deletion radionets/evaluation/scripts/start_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
evaluate_intensity_sampled,
evaluate_mean_diff,
evaluate_ms_ssim,
evaluate_ms_ssim_sampled,
evaluate_point,
evaluate_unc,
evaluate_viewing_angle,
Expand Down Expand Up @@ -103,7 +104,11 @@ def main(configuration_path):

if eval_conf["ms_ssim"]:
click.echo("\nStart evaluation of ms ssim.\n")
evaluate_ms_ssim(eval_conf)
samp_file = check_samp_file(eval_conf)
if samp_file:
evaluate_ms_ssim_sampled(eval_conf)
else:
evaluate_ms_ssim(eval_conf)

if eval_conf["intensity"]:
click.echo("\nStart evaluation of intensity.\n")
Expand Down
7 changes: 4 additions & 3 deletions radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,8 @@ def evaluate_ms_ssim_sampled(conf):
out_path = Path(model_path).parent / "evaluation"
out_path.mkdir(parents=True, exist_ok=True)

data_path = str(out_path) + "/sampled_imgs.h5"
name_model = Path(model_path).stem
data_path = str(out_path) + f"/sampled_imgs_{name_model}.h5"
loader = create_sampled_databunch(data_path, conf["batch_size"])
vals = []

Expand All @@ -586,10 +587,10 @@ def evaluate_ms_ssim_sampled(conf):
click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")

if conf["save_vals"]:
click.echo("\nSaving area ratios.\n")
click.echo("\nSaving msssim ratios.\n")
out = Path(conf["save_path"])
out.mkdir(parents=True, exist_ok=True)
np.savetxt(out / "area_ratios.txt", vals)
np.savetxt(out / "msssim_ratios.txt", vals)


def evaluate_area_sampled(conf):
Expand Down

0 comments on commit 799a1aa

Please sign in to comment.