Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restormer Implementation #8312

Open
wants to merge 66 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
3db93ce
Add new pixel unshuffle for SubPixelDownsample class
phisanti Jan 15, 2025
9693e04
Add unit test for pixelunshuffle
phisanti Jan 15, 2025
a89f299
Add DownSample Modes
phisanti Jan 15, 2025
450691f
expand pixelunshuffle for 3D
phisanti Jan 16, 2025
d0920d8
increase testing for pixelunshuffle
phisanti Jan 16, 2025
1a48d4d
expand pixelunshuffle for 3D images
phisanti Jan 16, 2025
fe47807
add SubpixelDownsample and tests
phisanti Jan 16, 2025
86155cd
Add DownSample Class
phisanti Jan 16, 2025
137a7f2
Add tests for Downsample
phisanti Jan 16, 2025
fb17baf
add exports to __init__
phisanti Jan 16, 2025
5ff0baa
Include test to compare with Conv + unshuffle from original restormer
phisanti Jan 16, 2025
2566db1
remove relative imports
phisanti Jan 16, 2025
ac4047b
Create restormer with Downsampler/Upsampler using monai implementation
phisanti Jan 17, 2025
2b74270
Add channel attention block
phisanti Jan 17, 2025
9b74533
add assembled restormer with MONAI convs for 3D
phisanti Jan 17, 2025
1ab34f6
restormer adapted for 2D/3D
phisanti Jan 20, 2025
4f4c62c
Add unit test for CABlock and the FeedForward layers
phisanti Jan 20, 2025
068688f
remove relative imports
phisanti Jan 20, 2025
e2e1070
rename restormer
phisanti Jan 20, 2025
35c7ee4
add unit test restormer
phisanti Jan 20, 2025
d8cb6c1
Update documentation and imports for CABlock and FeedForward; add Dow…
phisanti Jan 23, 2025
6d96816
Add licence to pixel_unshuffle test
phisanti Jan 23, 2025
8a688fb
Refactor imports and clean up whitespace in utils and test files and …
phisanti Jan 23, 2025
acb818d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2025
6352ba9
DCO Remediation Commit for tisalon <[email protected]>
phisanti Jan 23, 2025
c7b1af4
add optional_import to downsample block test
phisanti Jan 24, 2025
8faa5da
rename args and fix imports
phisanti Feb 7, 2025
be89958
Using LocalStore in Zarr v3 (#8299)
KumoLiu Jan 15, 2025
c17938b
8267 fix normalize intensity (#8286)
advcu987 Jan 20, 2025
64613a7
Fix bundle download error from ngc source (#8307)
KumoLiu Jan 21, 2025
5643d4a
Fix deprecated usage in zarr (#8313)
KumoLiu Jan 24, 2025
595674a
update pydicom reader to enable gpu load (#8283)
yiheng-wang-nv Jan 27, 2025
c775393
Zarr compression tests only with versions before 3.0 (#8319)
ericspod Feb 3, 2025
61efefb
Sync dev branch with upstream MONAI changes
phisanti Feb 7, 2025
091887b
Clarify input tensor shape in pixelshuffle and pixelunshuffle functio…
phisanti Feb 7, 2025
5d162d0
Refactor downsample mode checks to use enum values for clarity
phisanti Feb 7, 2025
f520e99
fix optiona import
phisanti Feb 7, 2025
39d1edf
Refactor layer normalization parameters for consistency and clarity i…
phisanti Feb 7, 2025
5b3d4e1
Enhance documentation for MDTATransformerBlock, OverlapPatchEmbed an…
phisanti Feb 7, 2025
1683b14
run ./runtests.sh --autofix to check formatting
phisanti Feb 7, 2025
232be1c
Refactor OverlapPatchEmbed to inherit from Convolution and streamline…
phisanti Feb 7, 2025
d1df8e6
Enhance documentation for FeedForward and CABlock classes, adding arg…
phisanti Feb 7, 2025
78ce56b
code formatting
phisanti Feb 7, 2025
64b203d
Update args naming in unit restormer test for consistency with sugges…
phisanti Feb 7, 2025
ce15886
Fix optional import
phisanti Feb 7, 2025
30fad17
require einops for all tests
phisanti Feb 7, 2025
1079d8c
require einops also for test_restormer
phisanti Feb 7, 2025
b2b3ddf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2025
174e968
remove relative impots
phisanti Feb 7, 2025
e15a815
fix capitalisation in DownSample documentation networks.rts
phisanti Feb 7, 2025
d53d97d
fix capitalisation in SubpixelDownsample documentation
phisanti Feb 7, 2025
cae7d96
formatting
phisanti Feb 7, 2025
a0afee5
update docstring to mention 2D and 3D cases
phisanti Feb 7, 2025
529e90b
Update type annotations and doctring
phisanti Feb 9, 2025
c109029
remove problematic unit test
phisanti Feb 9, 2025
19c30f7
remove problematic unit test
phisanti Feb 9, 2025
0b0e4df
Merge remote-tracking branch 'upstream/dev' into dev
phisanti Mar 1, 2025
55da640
relocate test in the correct place
phisanti Mar 1, 2025
3c2dbc6
Add DownSampleBlock missing tests, Signed-off-by: Santiago Cano-Muniz…
phisanti Mar 1, 2025
da0a186
Merge branch 'dev' into dev
phisanti Mar 8, 2025
f17e06e
Re-order skipUnless in test_restormer.py, Signed-off-by: Cano-Muniz, …
phisanti Mar 8, 2025
4573ec9
Clarify comments for RESTORMER_CONFIGS in test_restormer.py,
phisanti Mar 8, 2025
8c564aa
Remove duplicated test_CABlock.py as part of codebase cleanup. In add…
phisanti Mar 8, 2025
3e013fe
Refactor test cases in test_restormer.py to conditionally define clas…
phisanti Mar 8, 2025
06be2ef
formatting error in line 237. Solved by updating black from 24.10.0 t…
phisanti Mar 8, 2025
c02d794
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ Blocks
.. autoclass:: SABlock
:members:

`CABlock Block`
~~~~~~~~~~~~~~~
.. autoclass:: CABlock
:members:

`FeedForward Block`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FeedForward
:members:

`Squeeze-and-Excitation`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ChannelSELayer
Expand Down Expand Up @@ -173,6 +183,16 @@ Blocks
.. autoclass:: Subpixelupsample
.. autoclass:: SubpixelUpSample

`Downsampling`
~~~~~~~~~~~~~~
.. autoclass:: DownSample
:members:
.. autoclass:: Downsample
.. autoclass:: SubpixelDownsample
:members:
.. autoclass:: Subpixeldownsample
.. autoclass:: SubpixelDownSample

`Registration Residual Conv Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: RegistrationResidualConvBlock
Expand Down Expand Up @@ -625,6 +645,11 @@ Nets
.. autoclass:: ViT
:members:

`Restormer`
~~~~~~~~~~~
.. autoclass:: restormer
:members:

`ViTAutoEnc`
~~~~~~~~~~~~
.. autoclass:: ViTAutoEnc
Expand Down
3 changes: 2 additions & 1 deletion monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
from .aspp import SimpleASPP
from .backbone_fpn_utils import BackboneWithFPN
from .cablock import CABlock, FeedForward
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
from .crossattention import CrossAttentionBlock
from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
from .encoder import BaseEncoder
from .fcn import FCN, GCN, MCFCN, Refine
Expand Down
180 changes: 180 additions & 0 deletions monai/networks/blocks/cablock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.convolutions import Convolution
from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")

__all__ = ["FeedForward", "CABlock"]


class FeedForward(nn.Module):
"""Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.

Args:
spatial_dims: Number of spatial dimensions (2D or 3D)
dim: Number of input channels
ffn_expansion_factor: Factor to expand hidden features dimension
bias: Whether to use bias in convolution layers
"""

def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
super().__init__()
hidden_features = int(dim * ffn_expansion_factor)

self.project_in = Convolution(
spatial_dims=spatial_dims,
in_channels=dim,
out_channels=hidden_features * 2,
kernel_size=1,
bias=bias,
conv_only=True,
)

self.dwconv = Convolution(
spatial_dims=spatial_dims,
in_channels=hidden_features * 2,
out_channels=hidden_features * 2,
kernel_size=3,
strides=1,
padding=1,
groups=hidden_features * 2,
bias=bias,
conv_only=True,
)

self.project_out = Convolution(
spatial_dims=spatial_dims,
in_channels=hidden_features,
out_channels=dim,
kernel_size=1,
bias=bias,
conv_only=True,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
return self.project_out(F.gelu(x1) * x2)


class CABlock(nn.Module):
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>

Args:
spatial_dims: Number of spatial dimensions (2D or 3D)
dim: Number of input channels
num_heads: Number of attention heads
bias: Whether to use bias in convolution layers
flash_attention: Whether to use flash attention optimization. Defaults to False.

Raises:
ValueError: If flash attention is not available in current PyTorch version
ValueError: If spatial_dims is greater than 3
"""

def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
super().__init__()
if flash_attention and not hasattr(F, "scaled_dot_product_attention"):
raise ValueError("Flash attention not available")
if spatial_dims > 3:
raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}")
self.spatial_dims = spatial_dims
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.flash_attention = flash_attention

self.qkv = Convolution(
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True
)

self.qkv_dwconv = Convolution(
spatial_dims=spatial_dims,
in_channels=dim * 3,
out_channels=dim * 3,
kernel_size=3,
strides=1,
padding=1,
groups=dim * 3,
bias=bias,
conv_only=True,
)

self.project_out = Convolution(
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True
)

self._attention_fn = self._get_attention_fn()

def _get_attention_fn(self):
if self.flash_attention:
return self._flash_attention
return self._normal_attention

def _flash_attention(self, q, k, v):
"""Flash attention implementation using scaled dot-product attention."""
scale = float(self.temperature.mean())
out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)
return out

def _normal_attention(self, q, k, v):
"""Attention matrix multiplication with depth-wise convolutions."""
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
return attn @ v

def forward(self, x) -> torch.Tensor:
"""Forward pass for MDTA attention.
1. Apply depth-wise convolutions to Q, K, V
2. Reshape Q, K, V for multi-head attention
3. Compute attention matrix using flash or normal attention
4. Reshape and project out attention output"""
spatial_dims = x.shape[2:]

# Project and mix
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)

# Select attention
if self.spatial_dims == 2:
qkv_to_multihead = "b (head c) h w -> b head c (h w)"
multihead_to_qkv = "b head c (h w) -> b (head c) h w"
else: # dims == 3
qkv_to_multihead = "b (head c) d h w -> b head c (d h w)"
multihead_to_qkv = "b head c (d h w) -> b (head c) d h w"

# Reconstruct and project feature map
q = rearrange(q, qkv_to_multihead, head=self.num_heads)
k = rearrange(k, qkv_to_multihead, head=self.num_heads)
v = rearrange(v, qkv_to_multihead, head=self.num_heads)

q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)

out = self._attention_fn(q, k, v)
out = rearrange(
out,
multihead_to_qkv,
head=self.num_heads,
**dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)),
)

return self.project_out(out)
Loading
Loading