diff --git a/docs/source/networks.rst b/docs/source/networks.rst index e2e509a99b..3c8ea725a9 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -109,6 +109,16 @@ Blocks .. autoclass:: SABlock :members: +`CABlock Block` +~~~~~~~~~~~~~~~ +.. autoclass:: CABlock + :members: + +`FeedForward Block` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FeedForward + :members: + `Squeeze-and-Excitation` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ChannelSELayer @@ -173,6 +183,16 @@ Blocks .. autoclass:: Subpixelupsample .. autoclass:: SubpixelUpSample +`Downsampling` +~~~~~~~~~~~~~~ +.. autoclass:: DownSample + :members: +.. autoclass:: Downsample +.. autoclass:: SubpixelDownsample + :members: +.. autoclass:: Subpixeldownsample +.. autoclass:: SubpixelDownSample + `Registration Residual Conv Block` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: RegistrationResidualConvBlock @@ -625,6 +645,11 @@ Nets .. autoclass:: ViT :members: +`Restormer` +~~~~~~~~~~~ +.. autoclass:: restormer + :members: + `ViTAutoEnc` ~~~~~~~~~~~~ .. autoclass:: ViTAutoEnc diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 499caf2e0f..22af82d316 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -15,12 +15,13 @@ from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish from .aspp import SimpleASPP from .backbone_fpn_utils import BackboneWithFPN +from .cablock import CABlock, FeedForward from .convolutions import Convolution, ResidualUnit from .crf import CRF from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock -from .downsample import MaxAvgPool +from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .encoder import BaseEncoder from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py new file mode 100644 index 0000000000..72e4cc68d0 --- /dev/null +++ b/monai/networks/blocks/cablock.py @@ -0,0 +1,180 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.convolutions import Convolution +from monai.utils import optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + +__all__ = ["FeedForward", "CABlock"] + + +class FeedForward(nn.Module): + """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. + Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection. 
+ + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + ffn_expansion_factor: Factor to expand hidden features dimension + bias: Whether to use bias in convolution layers + """ + + def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool): + super().__init__() + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = Convolution( + spatial_dims=spatial_dims, + in_channels=dim, + out_channels=hidden_features * 2, + kernel_size=1, + bias=bias, + conv_only=True, + ) + + self.dwconv = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_features * 2, + out_channels=hidden_features * 2, + kernel_size=3, + strides=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + conv_only=True, + ) + + self.project_out = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_features, + out_channels=dim, + kernel_size=1, + bias=bias, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + return self.project_out(F.gelu(x1) * x2) + + +class CABlock(nn.Module): + """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention + by operating on feature channels instead of spatial dimensions. Incorporates depth-wise + convolutions for local mixing before attention, achieving linear complexity vs quadratic + in vanilla attention. Based on SW Zamir, et al., 2022 + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + num_heads: Number of attention heads + bias: Whether to use bias in convolution layers + flash_attention: Whether to use flash attention optimization. Defaults to False. + + Raises: + ValueError: If flash attention is not available in current PyTorch version + ValueError: If spatial_dims is greater than 3 + """ + + def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): + super().__init__() + if flash_attention and not hasattr(F, "scaled_dot_product_attention"): + raise ValueError("Flash attention not available") + if spatial_dims > 3: + raise ValueError(f"Only 2D and 3D inputs are supported. 
Got spatial_dims={spatial_dims}") + self.spatial_dims = spatial_dims + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.flash_attention = flash_attention + + self.qkv = Convolution( + spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True + ) + + self.qkv_dwconv = Convolution( + spatial_dims=spatial_dims, + in_channels=dim * 3, + out_channels=dim * 3, + kernel_size=3, + strides=1, + padding=1, + groups=dim * 3, + bias=bias, + conv_only=True, + ) + + self.project_out = Convolution( + spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True + ) + + self._attention_fn = self._get_attention_fn() + + def _get_attention_fn(self): + if self.flash_attention: + return self._flash_attention + return self._normal_attention + + def _flash_attention(self, q, k, v): + """Flash attention implementation using scaled dot-product attention.""" + scale = float(self.temperature.mean()) + out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False) + return out + + def _normal_attention(self, q, k, v): + """Attention matrix multiplication with depth-wise convolutions.""" + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + return attn @ v + + def forward(self, x) -> torch.Tensor: + """Forward pass for MDTA attention. + 1. Apply depth-wise convolutions to Q, K, V + 2. Reshape Q, K, V for multi-head attention + 3. Compute attention matrix using flash or normal attention + 4. Reshape and project out attention output""" + spatial_dims = x.shape[2:] + + # Project and mix + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # Select attention + if self.spatial_dims == 2: + qkv_to_multihead = "b (head c) h w -> b head c (h w)" + multihead_to_qkv = "b head c (h w) -> b (head c) h w" + else: # dims == 3 + qkv_to_multihead = "b (head c) d h w -> b head c (d h w)" + multihead_to_qkv = "b head c (d h w) -> b (head c) d h w" + + # Reconstruct and project feature map + q = rearrange(q, qkv_to_multihead, head=self.num_heads) + k = rearrange(k, qkv_to_multihead, head=self.num_heads) + v = rearrange(v, qkv_to_multihead, head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + out = self._attention_fn(q, k, v) + out = rearrange( + out, + multihead_to_qkv, + head=self.num_heads, + **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)), + ) + + return self.project_out(out) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 2a6a60ff8a..ae962287a9 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -16,8 +16,11 @@ import torch import torch.nn as nn -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep +from monai.networks.layers.factories import Conv, Pool +from monai.networks.utils import pixelunshuffle +from monai.utils import DownsampleMode, ensure_tuple_rep, look_up_option + +__all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"] class MaxAvgPool(nn.Module): @@ -61,3 +64,238 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]). 
""" return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) + + +class DownSample(nn.Sequential): + """ + Downsamples data by `scale_factor`. + + Supported modes are: + + - "conv": uses a strided convolution for learnable downsampling. + - "convgroup": uses a grouped strided convolution for efficient feature reduction. + - "nontrainable": uses :py:class:`torch.nn.Upsample` with inverse scale factor. + - "pixelunshuffle": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement. + + This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``. + Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + + This module can optionally take a pre-convolution + (often used to map the number of features from `in_channels` to `out_channels`). + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int | None = None, + out_channels: int | None = None, + scale_factor: Sequence[float] | float = 2, + kernel_size: Sequence[float] | float | None = None, + mode: str = "conv", # conv, convgroup, nontrainable, pixelunshuffle + pre_conv: nn.Module | str | None = "default", + post_conv: nn.Module | None = None, + bias: bool = True, + ) -> None: + """ + Downsamples data by `scale_factor`. + Supported modes are: + + - "conv": uses a strided convolution for learnable downsampling. + - "convgroup": uses a grouped strided convolution for efficient feature reduction. + - "maxpool": uses maxpooling for non-learnable downsampling. + - "avgpool": uses average pooling for non-learnable downsampling. + - "pixelunshuffle": uses :py:class:`monai.networks.blocks.SubpixelDownsample`. + + This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``. + Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + + This module can optionally take a pre-convolution and post-convolution + (often used to map the number of features from `in_channels` to `out_channels`). + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of channels of the input image. + out_channels: number of channels of the output image. Defaults to `in_channels`. + scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2. + kernel_size: kernel size used during convolutions. Defaults to `scale_factor`. + mode: {``"conv"``, ``"convgroup"``, ``"maxpool"``, ``"avgpool"``, ``"pixelunshuffle"``}. Defaults to ``"conv"``. + pre_conv: a conv block applied before downsampling. Defaults to "default". + When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. + Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes. + post_conv: a conv block applied after downsampling. Defaults to None. Only used in the "maxpool" and "avgpool" modes. + bias: whether to have a bias term in the default preconv and conv layers. Defaults to True. 
+ """ + super().__init__() + + scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims) + down_mode = look_up_option(mode, DownsampleMode) + + if not kernel_size: + kernel_size_ = scale_factor_ + padding = 0 + else: + kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims) + padding = tuple((k - 1) // 2 for k in kernel_size_) + + if down_mode == DownsampleMode.CONV: + if not in_channels: + raise ValueError("in_channels needs to be specified in conv mode") + self.add_module( + "conv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels or in_channels, + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding, + bias=bias, + ), + ) + elif down_mode == DownsampleMode.CONVGROUP: + if not in_channels: + raise ValueError("in_channels needs to be specified") + if out_channels is None: + out_channels = in_channels + groups = in_channels if out_channels % in_channels == 0 else 1 + self.add_module( + "convgroup", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding, + groups=groups, + bias=bias, + ), + ) + elif down_mode == DownsampleMode.MAXPOOL: + if pre_conv == "default" and (out_channels != in_channels): + if not in_channels: + raise ValueError("in_channels needs to be specified") + self.add_module( + "preconv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias + ), + ) + self.add_module( + "maxpool", Pool[Pool.MAX, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding) + ) + if post_conv: + self.add_module("postconv", post_conv) + + elif down_mode == DownsampleMode.AVGPOOL: + if pre_conv == "default" and (out_channels != in_channels): + if not in_channels: + raise ValueError("in_channels needs to be specified") + self.add_module( + "preconv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias + ), + ) + self.add_module( + "avgpool", Pool[Pool.AVG, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding) + ) + if post_conv: + self.add_module("postconv", post_conv) + + elif down_mode == DownsampleMode.PIXELUNSHUFFLE: + self.add_module( + "pixelunshuffle", + SubpixelDownsample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + scale_factor=scale_factor_[0], + conv_block=pre_conv, + bias=bias, + ), + ) + + +class SubpixelDownsample(nn.Module): + """ + Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images. + The module consists of two parts. First, a convolutional layer is employed + to adjust the number of channels. Secondly, a pixel unshuffle manipulation + rearranges the spatial information into channel space, effectively reducing + spatial dimensions while increasing channel depth. + + The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions + from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case. + + Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2). + + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using a nEfficient Sub-Pixel Convolutional Neural Network." 
+ + The pixel unshuffle mechanism is the inverse operation of: + https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int | None, + out_channels: int | None = None, + scale_factor: int = 2, + conv_block: nn.Module | str | None = "default", + bias: bool = True, + ) -> None: + """ + Downsamples data by rearranging spatial information into channel space. + This reduces spatial dimensions while increasing channel depth. + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of channels of the input image. + out_channels: optional number of channels of the output image. + scale_factor: factor to reduce the spatial dimensions by. Defaults to 2. + conv_block: a conv block to adjust channels before downsampling. Defaults to None. + When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. + When ``conv_block`` is an ``nn.module``, + please ensure the input number of channels matches requirements. + bias: whether to have a bias term in the default conv_block. Defaults to True. + """ + super().__init__() + + if scale_factor <= 0: + raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.") + + self.dimensions = spatial_dims + self.scale_factor = scale_factor + + if conv_block == "default": + if not in_channels: + raise ValueError("in_channels need to be specified.") + out_channels = out_channels or in_channels + self.conv_block = Conv[Conv.CONV, self.dimensions]( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=bias + ) + elif conv_block is None: + self.conv_block = nn.Identity() + else: + self.conv_block = conv_block + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). + Returns: + Tensor with reduced spatial dimensions and increased channel depth. + """ + x = self.conv_block(x) + if not all(d % self.scale_factor == 0 for d in x.shape[2:]): + raise ValueError( + f"All spatial dimensions {x.shape[2:]} must be evenly " f"divisible by scale_factor {self.scale_factor}" + ) + x = pixelunshuffle(x, self.dimensions, self.scale_factor) + return x + + +Downsample = DownSample +SubpixelDownSample = Subpixeldownsample = SubpixelDownsample diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py new file mode 100644 index 0000000000..b59150ad4d --- /dev/null +++ b/monai/networks/nets/restormer.py @@ -0,0 +1,336 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
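Editor's note (illustrative, not part of the patch): a minimal usage sketch of the downsampling blocks introduced above, with shapes that follow the unit tests added in this diff.

```python
# Minimal sketch of the new downsampling blocks (illustrative, not part of the patch).
import torch

from monai.networks.blocks import DownSample, SubpixelDownsample

x = torch.randn(1, 4, 64, 64)  # (B, C, H, W)

# "conv" mode: strided convolution, spatial size halves, channels preserved by default.
down_conv = DownSample(spatial_dims=2, in_channels=4, mode="conv")
print(down_conv(x).shape)  # torch.Size([1, 4, 32, 32])

# Subpixel (pixel-unshuffle) downsampling: each 2x2 spatial block is folded into
# channels, so channels grow by scale_factor**spatial_dims = 4.
down_pu = SubpixelDownsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None)
print(down_pu(x).shape)  # torch.Size([1, 16, 32, 32])
```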
+from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.networks.blocks.cablock import CABlock, FeedForward +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.downsample import DownSample, DownsampleMode +from monai.networks.blocks.upsample import UpSample, UpsampleMode +from monai.networks.layers.factories import Norm + + +class MDTATransformerBlock(nn.Module): + """Basic transformer unit combining MDTA and GDFN with skip connections. + Unlike standard transformers that use LayerNorm, this block uses Instance Norm + for better adaptation to image restoration tasks. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + num_heads: Number of attention heads + ffn_expansion_factor: Expansion factor for feed-forward network + bias: Whether to use bias in attention layers + layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False. + flash_attention: Whether to use flash attention optimization. Defaults to False. + """ + + def __init__( + self, + spatial_dims: int, + dim: int, + num_heads: int, + ffn_expansion_factor: float, + bias: bool, + layer_norm_use_bias: bool = False, + flash_attention: bool = False, + ): + super().__init__() + self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias) + self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) + self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias) + self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + + +class OverlapPatchEmbed(Convolution): + """Initial feature extraction using overlapped convolutions. + Unlike standard patch embeddings that use non-overlapping patches, + this approach maintains spatial continuity through 3x3 convolutions. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + in_channels: Number of input channels + embed_dim: Dimension of embedded features. Defaults to 48. + bias: Whether to use bias in convolution layer. Defaults to False. + """ + + def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False): + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embed_dim, + kernel_size=3, + strides=1, + padding=1, + bias=bias, + conv_only=True, + ) + + +def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x) + + +class Restormer(nn.Module): + """Restormer: Efficient Transformer for High-Resolution Image Restoration. + + Implements a U-Net style architecture with transformer blocks, combining: + - Multi-scale feature processing through progressive down/upsampling + - Efficient attention via MDTA blocks + - Local feature mixing through GDFN + - Skip connections for preserving spatial details + + Architecture: + - Encoder: Progressive feature downsampling with increasing channels + - Latent: Deep feature processing at lowest resolution + - Decoder: Progressive upsampling with skip connections + - Refinement: Final feature enhancement + """ + + def __init__( + self, + spatial_dims: int = 2, + in_channels: int = 3, + out_channels: int = 3, + dim: int = 48, + num_blocks: tuple[int, ...] = (1, 1, 1, 1), + heads: tuple[int, ...] 
= (1, 1, 1, 1), + num_refinement_blocks: int = 4, + ffn_expansion_factor: float = 2.66, + bias: bool = False, + layer_norm_use_bias: bool = True, + dual_pixel_task: bool = False, + flash_attention: bool = False, + ) -> None: + super().__init__() + """Initialize Restormer model. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + in_channels: Number of input image channels + out_channels: Number of output image channels + dim: Base feature dimension. Defaults to 48. + num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1). + heads: Number of attention heads at each scale. Defaults to (1,1,1,1). + num_refinement_blocks: Number of final refinement blocks. Defaults to 4. + ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66. + bias: Whether to use bias in convolutions. Defaults to False. + layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True. + dual_pixel_task: Enable dual-pixel specific processing. Defaults to False. + flash_attention: Use flash attention if available. Defaults to False. + + Note: + The number of blocks must be greater than 1 + The length of num_blocks and heads must be equal + All values in num_blocks must be greater than 0 + """ + # Check input parameters + assert len(num_blocks) > 1, "Number of blocks must be greater than 1" + assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal" + assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0" + + # Initial feature extraction + self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim) + self.encoder_levels = nn.ModuleList() + self.downsamples = nn.ModuleList() + self.decoder_levels = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.reduce_channels = nn.ModuleList() + num_steps = len(num_blocks) - 1 + self.num_steps = num_steps + self.spatial_dims = spatial_dims + spatial_multiplier = 2 ** (spatial_dims - 1) + + # Define encoder levels + for n in range(num_steps): + current_dim = dim * (2) ** (n) + next_dim = current_dim // spatial_multiplier + self.encoder_levels.append( + nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=current_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[n]) + ] + ) + ) + + self.downsamples.append( + DownSample( + spatial_dims=self.spatial_dims, + in_channels=current_dim, + out_channels=next_dim, + mode=DownsampleMode.PIXELUNSHUFFLE, + scale_factor=2, + bias=bias, + ) + ) + + # Define latent space + latent_dim = dim * (2) ** (num_steps) + self.latent = nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=latent_dim, + num_heads=heads[num_steps], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[num_steps]) + ] + ) + + # Define decoder levels + for n in reversed(range(num_steps)): + current_dim = dim * (2) ** (n) + next_dim = dim * (2) ** (n + 1) + self.upsamples.append( + UpSample( + spatial_dims=self.spatial_dims, + in_channels=next_dim, + out_channels=(current_dim), + mode=UpsampleMode.PIXELSHUFFLE, + scale_factor=2, + bias=bias, + apply_pad_pool=False, + ) + ) + + # Reduce channel layers to deal with skip connections + if n != 0: + self.reduce_channels.append( + Convolution( + spatial_dims=self.spatial_dims, + 
in_channels=next_dim, + out_channels=current_dim, + kernel_size=1, + bias=bias, + conv_only=True, + ) + ) + decoder_dim = current_dim + else: + decoder_dim = next_dim + + self.decoder_levels.append( + nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=decoder_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[n]) + ] + ) + ) + + # Final refinement and output + self.refinement = nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=decoder_dim, + num_heads=heads[0], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash_attention, + ) + for _ in range(num_refinement_blocks) + ] + ) + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = Convolution( + spatial_dims=self.spatial_dims, + in_channels=dim, + out_channels=dim * 2, + kernel_size=1, + bias=bias, + conv_only=True, + ) + self.output = Convolution( + spatial_dims=self.spatial_dims, + in_channels=dim * 2, + out_channels=out_channels, + kernel_size=3, + strides=1, + padding=1, + bias=bias, + conv_only=True, + ) + + def forward(self, x) -> torch.Tensor: + """Forward pass of Restormer. + Processes input through encoder-decoder architecture with skip connections. + Args: + inp_img: Input image tensor of shape (B, C, H, W, [D]) + + Returns: + Restored image tensor of shape (B, C, H, W, [D]) + """ + assert all( + x.shape[-i] > 2**self.num_steps for i in range(1, self.spatial_dims + 1) + ), "All spatial dimensions should be larger than 2^number_of_step" + + # Patch embedding + x = self.patch_embed(x) + skip_connections = [] + + # Encoding path + for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)): + x = encoder(x) + skip_connections.append(x) + x = downsample(x) + + # Latent space + x = self.latent(x) + + # Decoding path + for idx in range(len(self.decoder_levels)): + x = self.upsamples[idx](x) + x = torch.concat([x, skip_connections[-(idx + 1)]], 1) + if idx < len(self.decoder_levels) - 1: + x = self.reduce_channels[idx](x) + x = self.decoder_levels[idx](x) + + # Final refinement + x = self.refinement(x) + + if self.dual_pixel_task: + x = x + self.skip_conv(skip_connections[0]) + x = self.output(x) + else: + x = self.output(x) + + return x diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a41d4b1e33..df91c84bdf 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -49,6 +49,7 @@ "normal_init", "icnr_init", "pixelshuffle", + "pixelunshuffle", "eval_mode", "train_mode", "get_state_dict", @@ -376,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". Args: - x: Input tensor + x: Input tensor with shape BCHW[D] spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to rescale the spatial dimensions by, must be >=1 @@ -411,6 +412,48 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch return x +def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor: + """ + Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. + Inverse operation of pixelshuffle. 
+ + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network." + + See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". + + Args: + x: Input tensor with shape BCHW[D] + spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D + scale_factor: factor to reduce the spatial dimensions by, must be >=1 + + Returns: + Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D + or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor + and d is spatial_dims. + + Raises: + ValueError: When spatial dimensions are not divisible by scale_factor + """ + dim, factor = spatial_dims, scale_factor + input_size = list(x.size()) + batch_size, channels = input_size[:2] + scale_factor_mult = factor**dim + new_channels = channels * scale_factor_mult + + if any(d % factor != 0 for d in input_size[2:]): + raise ValueError( + f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}" + ) + output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]] + reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], []) + + permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)] + x = x.reshape(reshaped_size).permute(permute_indices) + x = x.reshape(output_size) + return x + + @contextmanager def eval_mode(*nets: nn.Module): """ diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 8f2f400b5d..3efc9b5e7f 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -29,6 +29,7 @@ CommonKeys, CompInitMode, DiceCEReduction, + DownsampleMode, EngineStatsKeys, FastMRIKeys, ForwardMode, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 3463a92e4b..f5bb6c4c5b 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -24,6 +24,7 @@ "SplineMode", "InterpolateMode", "UpsampleMode", + "DownsampleMode", "BlendMode", "PytorchPadMode", "NdimageMode", @@ -181,6 +182,18 @@ class UpsampleMode(StrEnum): PIXELSHUFFLE = "pixelshuffle" +class DownsampleMode(StrEnum): + """ + See also: :py:class:`monai.networks.blocks.UpSample` + """ + + CONV = "conv" # e.g. using strided convolution + CONVGROUP = "convgroup" # e.g. 
using grouped strided convolution + PIXELUNSHUFFLE = "pixelunshuffle" + MAXPOOL = "maxpool" + AVGPOOL = "avgpool" + + class BlendMode(StrEnum): """ See also: :py:class:`monai.data.utils.compute_importance_map` diff --git a/tests/integration/test_downsample_block.py b/tests/integration/test_downsample_block.py index 34afa248ad..5e660510d4 100644 --- a/tests/integration/test_downsample_block.py +++ b/tests/integration/test_downsample_block.py @@ -17,7 +17,10 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks import MaxAvgPool +from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -35,6 +38,20 @@ ], ] +TEST_CASES_SUBPIXEL = [ + [{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)], + [{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)], + [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], +] + +TEST_CASES_DOWNSAMPLE = [ + [{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)], + [{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)], +] + class TestMaxAvgPool(unittest.TestCase): @@ -46,5 +63,121 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) +class TestSubpixelDownsample(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SUBPIXEL) + def test_shape(self, input_param, input_shape, expected_shape): + downsampler = SubpixelDownsample(**input_param) + with eval_mode(downsampler): + result = downsampler(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_predefined_tensor(self): + test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) + test_tensor = test_tensor.unsqueeze(0) + + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + with eval_mode(downsampler): + result = downsampler(test_tensor) + self.assertEqual(result.shape, (1, 16, 2, 2)) + self.assertTrue(torch.all(result[0, 0:3] == 0)) + self.assertTrue(torch.all(result[0, 4:7] == 1)) + self.assertTrue(torch.all(result[0, 8:11] == 2)) + self.assertTrue(torch.all(result[0, 12:15] == 3)) + + def test_reconstruction_2d(self): + input_tensor = torch.randn(1, 1, 4, 4) + down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_reconstruction_3d(self): + input_tensor = torch.randn(1, 1, 4, 4, 4) + down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, 
apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_invalid_spatial_size(self): + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) + with self.assertRaises(ValueError): + downsampler(torch.randn(1, 1, 3, 4)) + + def test_custom_conv_block(self): + custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1) + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv) + with eval_mode(downsampler): + result = downsampler(torch.randn(1, 1, 4, 4)) + self.assertEqual(result.shape, (1, 8, 2, 2)) + + +class TestDownSample(unittest.TestCase): + @parameterized.expand(TEST_CASES_DOWNSAMPLE) + def test_shape(self, input_param, input_shape, expected_shape): + net = DownSample(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_pre_post_conv(self): + net = DownSample( + spatial_dims=2, + in_channels=4, + out_channels=8, + mode="maxpool", + pre_conv="default", + post_conv=torch.nn.Conv2d(8, 16, 1), + ) + with eval_mode(net): + result = net(torch.randn(1, 4, 16, 16)) + self.assertEqual(result.shape, (1, 16, 8, 8)) + + def test_pixelunshuffle_equivalence(self): + class DownSampleLocal(torch.nn.Module): + def __init__(self, n_feat: int): + super().__init__() + self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.pixelunshuffle = torch.nn.PixelUnshuffle(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + return self.pixelunshuffle(x) + + n_feat = 2 + x = torch.randn(1, n_feat, 64, 64) + + fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + + monai_down = DownSample( + spatial_dims=2, + in_channels=n_feat, + out_channels=n_feat // 2, + mode="pixelunshuffle", + pre_conv=fix_weight_conv, + ) + + local_down = DownSampleLocal(n_feat) + local_down.conv.weight.data = fix_weight_conv.weight.data.clone() + + with eval_mode(monai_down), eval_mode(local_down): + out_monai = monai_down(x) + out_local = local_down(x) + + self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) + + def test_invalid_mode(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, in_channels=4, mode="invalid") + + def test_missing_channels(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, mode="conv") + + if __name__ == "__main__": unittest.main() diff --git a/tests/networks/blocks/test_CABlock.py b/tests/networks/blocks/test_CABlock.py new file mode 100644 index 0000000000..42531131c5 --- /dev/null +++ b/tests/networks/blocks/test_CABlock.py @@ -0,0 +1,150 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
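Editor's note (illustrative, not part of the patch): as context for the tests below, a minimal sketch of how the MDTA attention (`CABlock`) and GDFN (`FeedForward`) blocks combine; `einops` is required.

```python
# Minimal sketch of CABlock + FeedForward (illustrative, not part of the patch).
import torch

from monai.networks.blocks.cablock import CABlock, FeedForward

x = torch.randn(2, 48, 32, 32)  # (B, C, H, W)

attn = CABlock(spatial_dims=2, dim=48, num_heads=4, bias=False)
ffn = FeedForward(spatial_dims=2, dim=48, ffn_expansion_factor=2.66, bias=False)

# Both blocks preserve the input shape, so they stack with residual connections,
# which is how MDTATransformerBlock in restormer.py uses them (with instance norm).
y = x + attn(x)
y = y + ffn(y)
print(y.shape)  # torch.Size([2, 48, 32, 32])
```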
+ +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.cablock import CABlock, FeedForward +from monai.utils import optional_import +from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose + +einops, has_einops = optional_import("einops") + + +TEST_CASES_CAB = [] +for spatial_dims in [2, 3]: + for dim in [32, 64, 128]: + for num_heads in [2, 4, 8]: + for bias in [True, False]: + test_case = [ + { + "spatial_dims": spatial_dims, + "dim": dim, + "num_heads": num_heads, + "bias": bias, + "flash_attention": False, + }, + (2, dim, *([16] * spatial_dims)), + (2, dim, *([16] * spatial_dims)), + ] + TEST_CASES_CAB.append(test_case) + + +TEST_CASES_FEEDFORWARD = [ + # Test different spatial dims, dimensions and expansion factors + [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)], + [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)], + [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)], +] + + +class TestFeedForward(unittest.TestCase): + + @parameterized.expand(TEST_CASES_FEEDFORWARD) + def test_shape(self, input_param, input_shape): + net = FeedForward(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, input_shape) + + def test_gating_mechanism(self): + net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True) + x = torch.ones(1, 32, 16, 16) + out = net(x) + self.assertNotEqual(torch.sum(out), torch.sum(x)) + + +class TestCABlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_CAB) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = CABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_invalid_spatial_dims(self): + with self.assertRaises(ValueError): + CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) + + @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") + def test_flash_attention(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + x = torch.randn(2, 64, 32, 32).to(device) + output = block(x) + self.assertEqual(output.shape, x.shape) + + @skipUnless(has_einops, "Requires einops") + def test_temperature_parameter(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + self.assertTrue(isinstance(block.temperature, torch.nn.Parameter)) + self.assertEqual(block.temperature.shape, (4, 1, 1)) + + @skipUnless(has_einops, "Requires einops") + def test_qkv_transformation_2d(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 32, 32) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 32, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_qkv_transformation_3d(self): + block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 16, 16, 16) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 16, 16, 16)) + + @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") + def test_flash_vs_normal_attention(self): + device = 
"cuda" if torch.cuda.is_available() else "cpu" + block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device) + + block_normal.load_state_dict(block_flash.state_dict()) + + x = torch.randn(2, 64, 32, 32).to(device) + with torch.no_grad(): + out_flash = block_flash(x) + out_normal = block_normal(x) + + assert_allclose(out_flash, out_normal, atol=1e-4) + + @skipUnless(has_einops, "Requires einops") + def test_deterministic_small_input(self): + block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False) + with torch.no_grad(): + block.qkv.conv.weight.data.fill_(1.0) + block.qkv_dwconv.conv.weight.data.fill_(1.0) + block.temperature.data.fill_(1.0) + block.project_out.conv.weight.data.fill_(1.0) + + x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32) + + output = block(x) + # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72 + expected = torch.full_like(x, 72.0) + + assert_allclose(output, expected, atol=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/blocks/test_downsample_block.py b/tests/networks/blocks/test_downsample_block.py new file mode 100644 index 0000000000..5e660510d4 --- /dev/null +++ b/tests/networks/blocks/test_downsample_block.py @@ -0,0 +1,183 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES = [ + [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 + [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16 + [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16 + [ # 4-channel 3D, batch 16 + {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True}, + (16, 4, 32, 24, 48), + (16, 8, 11, 8, 16), + ], + [ # 1-channel 3D, batch 16 + {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False}, + (16, 1, 32, 24, 48), + (16, 2, 10, 8, 16), + ], +] + +TEST_CASES_SUBPIXEL = [ + [{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)], + [{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)], + [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], +] + +TEST_CASES_DOWNSAMPLE = [ + [{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)], + [{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)], +] + + +class TestMaxAvgPool(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + net = MaxAvgPool(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +class TestSubpixelDownsample(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SUBPIXEL) + def test_shape(self, input_param, input_shape, expected_shape): + downsampler = SubpixelDownsample(**input_param) + with eval_mode(downsampler): + result = downsampler(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_predefined_tensor(self): + test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) + test_tensor = test_tensor.unsqueeze(0) + + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + with eval_mode(downsampler): + result = downsampler(test_tensor) + self.assertEqual(result.shape, (1, 16, 2, 2)) + self.assertTrue(torch.all(result[0, 0:3] == 0)) + self.assertTrue(torch.all(result[0, 4:7] == 1)) + self.assertTrue(torch.all(result[0, 8:11] == 2)) + self.assertTrue(torch.all(result[0, 12:15] == 3)) + + def test_reconstruction_2d(self): + input_tensor = torch.randn(1, 1, 4, 4) + down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_reconstruction_3d(self): + input_tensor = torch.randn(1, 1, 4, 4, 4) + down = 
SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_invalid_spatial_size(self): + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) + with self.assertRaises(ValueError): + downsampler(torch.randn(1, 1, 3, 4)) + + def test_custom_conv_block(self): + custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1) + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv) + with eval_mode(downsampler): + result = downsampler(torch.randn(1, 1, 4, 4)) + self.assertEqual(result.shape, (1, 8, 2, 2)) + + +class TestDownSample(unittest.TestCase): + @parameterized.expand(TEST_CASES_DOWNSAMPLE) + def test_shape(self, input_param, input_shape, expected_shape): + net = DownSample(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_pre_post_conv(self): + net = DownSample( + spatial_dims=2, + in_channels=4, + out_channels=8, + mode="maxpool", + pre_conv="default", + post_conv=torch.nn.Conv2d(8, 16, 1), + ) + with eval_mode(net): + result = net(torch.randn(1, 4, 16, 16)) + self.assertEqual(result.shape, (1, 16, 8, 8)) + + def test_pixelunshuffle_equivalence(self): + class DownSampleLocal(torch.nn.Module): + def __init__(self, n_feat: int): + super().__init__() + self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.pixelunshuffle = torch.nn.PixelUnshuffle(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + return self.pixelunshuffle(x) + + n_feat = 2 + x = torch.randn(1, n_feat, 64, 64) + + fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + + monai_down = DownSample( + spatial_dims=2, + in_channels=n_feat, + out_channels=n_feat // 2, + mode="pixelunshuffle", + pre_conv=fix_weight_conv, + ) + + local_down = DownSampleLocal(n_feat) + local_down.conv.weight.data = fix_weight_conv.weight.data.clone() + + with eval_mode(monai_down), eval_mode(local_down): + out_monai = monai_down(x) + out_local = local_down(x) + + self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) + + def test_invalid_mode(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, in_channels=4, mode="invalid") + + def test_missing_channels(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, mode="conv") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py new file mode 100644 index 0000000000..7259766bd0 --- /dev/null +++ b/tests/networks/nets/test_restormer.py @@ -0,0 +1,156 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES_TRANSFORMER = [ + # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] + [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)], + [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)], + [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)], + [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)], +] + +TEST_CASES_PATCHEMBED = [ + # spatial_dims, in_channels, embed_dim, input_shape, expected_shape + [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)], + [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)], + [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)], + [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)], +] + +RESTORMER_CONFIGS = [ + # 2-level architecture test + {"num_blocks": [1, 1], "heads": [1, 1]}, + {"num_blocks": [2, 1], "heads": [2, 1]}, + # 3-level architecture test + {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, + {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, +] + +TEST_CASES_RESTORMER = [] +for config in RESTORMER_CONFIGS: + # 2D cases + TEST_CASES_RESTORMER.extend( + [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 64, 64), + (2, 1, 64, 64), + ], + # 3D cases + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "dim": 16, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 32, 32, 32), + (2, 1, 32, 32, 32), + ], + ] + ) + + +if has_einops: + class TestMDTATransformerBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_TRANSFORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): + if flash and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + block = MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=dim, + num_heads=heads, + ffn_expansion_factor=ffn_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash, + ) + with eval_mode(block): + x = torch.randn(shape) + output = block(x) + self.assertEqual(output.shape, x.shape) +else: + class TestMDTATransformerBlock(unittest.TestCase): + def test_placeholder(self): + self.skipTest("Einops module not available") + + +class TestOverlapPatchEmbed(unittest.TestCase): + + @parameterized.expand(TEST_CASES_PATCHEMBED) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): + net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + +if has_einops: + class TestRestormer(unittest.TestCase): + + @parameterized.expand(TEST_CASES_RESTORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, 
expected_shape): + if input_param.get("flash_attention", False) and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + net = Restormer(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_2d(self): + net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8)) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_3d(self): + net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8, 8)) +else: + class TestRestormer(unittest.TestCase): + def test_placeholder(self): + self.skipTest("Einops module not available") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/utils/test_pixelunshuffle.py b/tests/networks/utils/test_pixelunshuffle.py new file mode 100644 index 0000000000..49b61440e5 --- /dev/null +++ b/tests/networks/utils/test_pixelunshuffle.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.utils import pixelshuffle, pixelunshuffle + + +class TestPixelUnshuffle(unittest.TestCase): + + def test_2d_basic(self): + x = torch.randn(2, 4, 16, 16) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + self.assertEqual(out.shape, (2, 16, 8, 8)) + + def test_3d_basic(self): + x = torch.randn(2, 4, 16, 16, 16) + out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) + self.assertEqual(out.shape, (2, 32, 8, 8, 8)) + + def test_non_square_input(self): + x = torch.arange(192).reshape(1, 2, 12, 8) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2)) + + def test_different_scale_factor(self): + x = torch.arange(360).reshape(1, 2, 12, 15) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=3) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3)) + + def test_inverse_operation(self): + x = torch.arange(4096).reshape(1, 8, 8, 8, 8) + shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2) + unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) + torch.testing.assert_close(x, unshuffled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py new file mode 100644 index 0000000000..49b61440e5 --- /dev/null +++ b/tests/test_pixelunshuffle.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.utils import pixelshuffle, pixelunshuffle + + +class TestPixelUnshuffle(unittest.TestCase): + + def test_2d_basic(self): + x = torch.randn(2, 4, 16, 16) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + self.assertEqual(out.shape, (2, 16, 8, 8)) + + def test_3d_basic(self): + x = torch.randn(2, 4, 16, 16, 16) + out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) + self.assertEqual(out.shape, (2, 32, 8, 8, 8)) + + def test_non_square_input(self): + x = torch.arange(192).reshape(1, 2, 12, 8) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2)) + + def test_different_scale_factor(self): + x = torch.arange(360).reshape(1, 2, 12, 15) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=3) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3)) + + def test_inverse_operation(self): + x = torch.arange(4096).reshape(1, 8, 8, 8, 8) + shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2) + unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) + torch.testing.assert_close(x, unshuffled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_restormer.py b/tests/test_restormer.py new file mode 100644 index 0000000000..ab08d84390 --- /dev/null +++ b/tests/test_restormer.py @@ -0,0 +1,147 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
+from monai.utils import optional_import
+
+einops, has_einops = optional_import("einops")
+
+TEST_CASES_TRANSFORMER = [
+    # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
+    [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
+    [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
+    [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
+    [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
+]
+
+TEST_CASES_PATCHEMBED = [
+    # spatial_dims, in_channels, embed_dim, input_shape, expected_shape
+    [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
+    [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
+    [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
+    [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)],
+]
+
+RESTORMER_CONFIGS = [
+    # 2-level architecture
+    {"num_blocks": [1, 1], "heads": [1, 1]},
+    {"num_blocks": [2, 1], "heads": [2, 1]},
+    # 3-level architecture
+    {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]},
+    {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]},
+]
+
+TEST_CASES_RESTORMER = []
+for config in RESTORMER_CONFIGS:
+    # 2D cases
+    TEST_CASES_RESTORMER.extend(
+        [
+            [
+                {
+                    "spatial_dims": 2,
+                    "in_channels": 1,
+                    "out_channels": 1,
+                    "dim": 48,
+                    "num_blocks": config["num_blocks"],
+                    "heads": config["heads"],
+                    "num_refinement_blocks": 2,
+                    "ffn_expansion_factor": 1.5,
+                },
+                (2, 1, 64, 64),
+                (2, 1, 64, 64),
+            ],
+            # 3D cases
+            [
+                {
+                    "spatial_dims": 3,
+                    "in_channels": 1,
+                    "out_channels": 1,
+                    "dim": 16,
+                    "num_blocks": config["num_blocks"],
+                    "heads": config["heads"],
+                    "num_refinement_blocks": 2,
+                    "ffn_expansion_factor": 1.5,
+                },
+                (2, 1, 32, 32, 32),
+                (2, 1, 32, 32, 32),
+            ],
+        ]
+    )
+
+
+class TestMDTATransformerBlock(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES_TRANSFORMER)
+    @skipUnless(has_einops, "Requires einops")
+    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
+        if flash and not torch.cuda.is_available():
+            self.skipTest("Flash attention requires CUDA")
+        block = MDTATransformerBlock(
+            spatial_dims=spatial_dims,
+            dim=dim,
+            num_heads=heads,
+            ffn_expansion_factor=ffn_factor,
+            bias=bias,
+            layer_norm_use_bias=layer_norm_use_bias,
+            flash_attention=flash,
+        )
+        with eval_mode(block):
+            x = torch.randn(shape)
+            output = block(x)
+            self.assertEqual(output.shape, x.shape)
+
+
+class TestOverlapPatchEmbed(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES_PATCHEMBED)
+    def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
+        net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
+
+class TestRestormer(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES_RESTORMER)
+    @skipUnless(has_einops, "Requires einops")
+    def test_shape(self, input_param, input_shape, expected_shape):
+        if input_param.get("flash_attention", False) and not torch.cuda.is_available():
+            self.skipTest("Flash attention requires CUDA")
+        net = Restormer(**input_param)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_small_input_error_2d(self):
+        net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
+        with self.assertRaises(AssertionError):
+            net(torch.randn(1, 1, 8, 8))
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_small_input_error_3d(self):
+        net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
+        with self.assertRaises(AssertionError):
+            net(torch.randn(1, 1, 8, 8, 8))
+
+
+if __name__ == "__main__":
+    unittest.main()