Skip to content

Commit

Permalink
Fixes for mypy type checking
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Kerfoot <[email protected]>
  • Loading branch information
ericspod committed Mar 7, 2025
1 parent db1d2af commit 160df23
Show file tree
Hide file tree
Showing 16 changed files with 99 additions and 99 deletions.
10 changes: 5 additions & 5 deletions monai/apps/detection/networks/retinanet_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]

self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
Expand Down Expand Up @@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]

def forward(self, x: list[Tensor]) -> list[Tensor]:
"""
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(
)
self.feature_extractor = feature_extractor

self.feature_map_channels: int = self.feature_extractor.out_channels
self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
self.num_anchors = num_anchors
self.classification_head = RetinaNetClassificationHead(
self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/box_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,15 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:

pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]

# When convert float32 to float16, Inf or Nan may occur
if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
raise ValueError("pred_whd_axis is NaN or Inf.")

# Distance from center to box's corner.
c_to_c_whd_axis = (
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
)

pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/reconstruction/networks/blocks/varnetblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
Returns:
Output of DC block with the same shape as x
"""
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore

def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
"""
Expand Down
2 changes: 1 addition & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
if isinstance(_affine, torch.Tensor):
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
else:
spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
if suppress_zeros:
spacing[spacing == 0] = 1.0
spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion monai/data/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]:
for codec, ext in all_codecs.items():
writer = cv2.VideoWriter()
fname = os.path.join(tmp_dir, f"test{ext}")
fourcc = cv2.VideoWriter_fourcc(*codec)
fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
noviderr = writer.open(fname, fourcc, 1, (10, 10))
if noviderr:
codecs[codec] = ext
Expand Down
2 changes: 1 addition & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam
self.scheduler = scheduler

def get_target(self, images, noise, timesteps):
return self.scheduler.get_velocity(images, noise, timesteps)
return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator]


def default_make_latent(
Expand Down
12 changes: 6 additions & 6 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def sample(
)

# 2. compute previous image: x_t -> x_t-1
image, _ = scheduler.step(model_output, t, image)
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
if save_intermediates and t % intermediate_steps == 0:
intermediates.append(image)
if save_intermediates:
Expand Down Expand Up @@ -986,8 +986,8 @@ def get_likelihood(
predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

# get the posterior mean and variance
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]

log_posterior_variance = torch.log(posterior_variance)
log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
Expand Down Expand Up @@ -1436,7 +1436,7 @@ def sample( # type: ignore[override]
)

# 3. compute previous image: x_t -> x_t-1
image, _ = scheduler.step(model_output, t, image)
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
if save_intermediates and t % intermediate_steps == 0:
intermediates.append(image)
if save_intermediates:
Expand Down Expand Up @@ -1562,8 +1562,8 @@ def get_likelihood( # type: ignore[override]
predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

# get the posterior mean and variance
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]

log_posterior_variance = torch.log(posterior_variance)
log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
Expand Down
26 changes: 13 additions & 13 deletions monai/inferers/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __init__(
cropped_shape: Sequence[int] | None = None,
device: torch.device | str | None = None,
) -> None:
self.merged_shape = merged_shape
self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
self.merged_shape: tuple[int, ...] = tuple(merged_shape)
self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
self.device = device
self.is_finalized = False

Expand Down Expand Up @@ -231,9 +231,9 @@ def __init__(
dtype: np.dtype | str = "float32",
value_dtype: np.dtype | str = "float32",
count_dtype: np.dtype | str = "uint8",
store: zarr.storage.Store | str = "merged.zarr",
value_store: zarr.storage.Store | str | None = None,
count_store: zarr.storage.Store | str | None = None,
store: zarr.storage.Store | str = "merged.zarr", # type: ignore
value_store: zarr.storage.Store | str | None = None, # type: ignore
count_store: zarr.storage.Store | str | None = None, # type: ignore
compressor: str | None = None,
value_compressor: str | None = None,
count_compressor: str | None = None,
Expand All @@ -251,18 +251,18 @@ def __init__(
if version_geq(get_package_version("zarr"), "3.0.0"):
if value_store is None:
self.tmpdir = TemporaryDirectory()
self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
self.value_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore
else:
self.value_store = value_store
self.value_store = value_store # type: ignore
if count_store is None:
self.tmpdir = TemporaryDirectory()
self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
self.count_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore
else:
self.count_store = count_store
self.count_store = count_store # type: ignore
else:
self.tmpdir = None
self.value_store = zarr.storage.TempStore() if value_store is None else value_store
self.count_store = zarr.storage.TempStore() if count_store is None else count_store
self.value_store = zarr.storage.TempStore() if value_store is None else value_store # type: ignore
self.count_store = zarr.storage.TempStore() if count_store is None else count_store # type: ignore
self.chunks = chunks
self.compressor = compressor
self.value_compressor = value_compressor
Expand Down Expand Up @@ -314,7 +314,7 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
with self.lock:
self.values[map_slice] += values.numpy()
self.counts[map_slice] += 1
self.counts[map_slice] += 1 # type: ignore[operator]

def finalize(self) -> zarr.Array:
"""
Expand All @@ -332,7 +332,7 @@ def finalize(self) -> zarr.Array:
if not self.is_finalized:
# use chunks for division to fit into memory
for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
self.output[chunk] = self.values[chunk] / self.counts[chunk]
self.output[chunk] = self.values[chunk] / self.counts[chunk] # type: ignore[operator]
# finalize the shape
self.output.resize(self.cropped_shape)
# set finalize flag to protect performing in-place division again
Expand Down
4 changes: 2 additions & 2 deletions monai/losses/sure_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Callable, Optional
from typing import Callable, Optional, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -92,7 +92,7 @@ def sure_loss_function(
y_ref = operator(x)

# get perturbed output
x_perturbed = x + eps * perturb_noise
x_perturbed = x + eps * perturb_noise # type: ignore
y_perturbed = operator(x_perturbed)
# divergence
divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore
Expand Down
6 changes: 4 additions & 2 deletions monai/networks/blocks/feature_pyramid_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@

from collections import OrderedDict
from collections.abc import Callable
from typing import cast

import torch
import torch.nn.functional as F
from torch import Tensor, nn

Expand Down Expand Up @@ -194,8 +196,8 @@ def __init__(
conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
for m in self.modules():
if isinstance(m, conv_type_):
nn.init.kaiming_uniform_(m.weight, a=1)
nn.init.constant_(m.bias, 0.0)
nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)

if extra_blocks is not None:
if not isinstance(extra_blocks, ExtraFPNBlock):
Expand Down
93 changes: 45 additions & 48 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,53 +272,50 @@ def __init__(
self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

def load_from(self, weights):
layers1_0: BasicLayer = self.swinViT.layers1[0] # type: ignore[assignment]
layers2_0: BasicLayer = self.swinViT.layers2[0] # type: ignore[assignment]
layers3_0: BasicLayer = self.swinViT.layers3[0] # type: ignore[assignment]
layers4_0: BasicLayer = self.swinViT.layers4[0] # type: ignore[assignment]
wstate = weights["state_dict"]

with torch.no_grad():
self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
for bname, block in self.swinViT.layers1[0].blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers1")
self.swinViT.layers1[0].downsample.reduction.weight.copy_(
weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
)
self.swinViT.layers1[0].downsample.norm.weight.copy_(
weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
)
self.swinViT.layers1[0].downsample.norm.bias.copy_(
weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
)
for bname, block in self.swinViT.layers2[0].blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers2")
self.swinViT.layers2[0].downsample.reduction.weight.copy_(
weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
)
self.swinViT.layers2[0].downsample.norm.weight.copy_(
weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
)
self.swinViT.layers2[0].downsample.norm.bias.copy_(
weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
)
for bname, block in self.swinViT.layers3[0].blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers3")
self.swinViT.layers3[0].downsample.reduction.weight.copy_(
weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
)
self.swinViT.layers3[0].downsample.norm.weight.copy_(
weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
)
self.swinViT.layers3[0].downsample.norm.bias.copy_(
weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
)
for bname, block in self.swinViT.layers4[0].blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers4")
self.swinViT.layers4[0].downsample.reduction.weight.copy_(
weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
)
self.swinViT.layers4[0].downsample.norm.weight.copy_(
weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
)
self.swinViT.layers4[0].downsample.norm.bias.copy_(
weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
)
self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
for bname, block in layers1_0.blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers1") # type: ignore[operator]

if layers1_0.downsample is not None:
d = layers1_0.downsample
d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"]) # type: ignore
d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"]) # type: ignore
d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"]) # type: ignore

for bname, block in layers2_0.blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers2") # type: ignore[operator]

if layers2_0.downsample is not None:
d = layers2_0.downsample
d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"]) # type: ignore
d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"]) # type: ignore
d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"]) # type: ignore

for bname, block in layers3_0.blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers3") # type: ignore[operator]

if layers3_0.downsample is not None:
d = layers3_0.downsample
d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"]) # type: ignore
d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"]) # type: ignore
d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"]) # type: ignore

for bname, block in layers4_0.blocks.named_children():
block.load_from(weights, n_block=bname, layer="layers4") # type: ignore[operator]

if layers4_0.downsample is not None:
d = layers4_0.downsample
d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"]) # type: ignore
d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore
d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore

@torch.jit.unused
def _check_input_size(self, spatial_shape):
Expand Down Expand Up @@ -532,7 +529,7 @@ def forward(self, x, mask):
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.clone()[:n, :n].reshape(-1)
self.relative_position_index.clone()[:n, :n].reshape(-1) # type: ignore[operator]
].reshape(n, n, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
Expand Down Expand Up @@ -691,7 +688,7 @@ def load_from(self, weights, n_block, layer):
self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) # type: ignore[operator]
self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/nets/transchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, *inputs, **kwargs) -> None:

def init_bert_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # type: ignore[union-attr,arg-type]
elif isinstance(module, torch.nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
Expand Down
Loading

0 comments on commit 160df23

Please sign in to comment.