From 8888a488a37a499d0b8b5128123084f7a97703c1 Mon Sep 17 00:00:00 2001 From: John Zielke Date: Tue, 4 Feb 2025 14:52:12 +0000 Subject: [PATCH] MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 Signed-off-by: John Zielke --- .../maisi/networks/autoencoderkl_maisi.py | 32 ++++++++----------- tests/test_autoencoderkl_maisi.py | 26 ++++++++++----- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 6251ea8e83..ea63114b5d 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -22,6 +22,7 @@ from monai.networks.blocks import Convolution from monai.networks.blocks.spatialattention import SpatialAttentionBlock from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.type_conversion import convert_to_tensor # Set up logging configuration @@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None: return +@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0") class MaisiGroupNorm3D(nn.GroupNorm): """ Custom 3D Group Normalization with optional print_info output. @@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm): num_channels: Number of channels for the group norm. eps: Epsilon value for numerical stability. affine: Whether to use learnable affine parameters, default to `True`. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + norm_float16: Deprecated argument. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -59,7 +61,6 @@ def __init__( save_mem: bool = True, ): super().__init__(num_groups, num_channels, eps, affine) - self.norm_float16 = norm_float16 self.print_info = print_info self.save_mem = save_mem @@ -67,6 +68,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.print_info: logger.info(f"MaisiGroupNorm3D with input size: {input.size()}") + target_dtype = input.dtype + if len(input.shape) != 5: raise ValueError("Expected a 5D tensor") @@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: inputs = [] for i in range(input.size(1)): - array = input[:, i : i + 1, ...].to(dtype=torch.float32) + array = input[:, i : i + 1, ...] mean = array.mean([2, 3, 4, 5], keepdim=True) std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() - if self.norm_float16: - inputs.append(((array - mean) / std).to(dtype=torch.float16)) - else: - inputs.append((array - mean) / std) + inputs.append(((array - mean) / std).to(dtype=target_dtype)) del input _empty_cuda_cache(self.save_mem) @@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0") class MaisiResBlock(nn.Module): """ Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a @@ -417,7 +418,6 @@ def __init__( num_channels=in_channels, eps=norm_eps, affine=True, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -439,7 +439,6 @@ def __init__( num_channels=out_channels, eps=norm_eps, affine=True, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out_tensor +@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0") class MaisiEncoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module): use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + norm_float16: Deprecated argument. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -591,7 +591,6 @@ def __init__( out_channels=output_channel, num_splits=num_splits, dim_split=dim_split, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -660,7 +659,6 @@ def __init__( num_channels=num_channels[-1], eps=norm_eps, affine=True, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0") class MaisiDecoder(nn.Module): """ Convolutional cascade upsampling from a spatial latent space into an image space. @@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + norm_float16: Deprecated argument. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -809,7 +808,6 @@ def __init__( out_channels=block_out_ch, num_splits=num_splits, dim_split=dim_split, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -848,7 +846,6 @@ def __init__( num_channels=block_in_ch, eps=norm_eps, affine=True, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0") class AutoencoderKlMaisi(AutoencoderKL): """ AutoencoderKL with custom MaisiEncoder and MaisiDecoder. @@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + norm_float16: Deprecated argument. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -964,7 +962,6 @@ def __init__( use_flash_attention=use_flash_attention, num_splits=num_splits, dim_split=dim_split, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) @@ -985,7 +982,6 @@ def __init__( use_convtranspose=use_convtranspose, num_splits=num_splits, dim_split=dim_split, - norm_float16=norm_float16, print_info=print_info, save_mem=save_mem, ) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 0e9f427fb6..81a799e1ba 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -75,28 +75,38 @@ else: CASES = CASES_NO_ATTENTION +test_dtypes = [torch.float32] +if device.type == "cuda": + test_dtypes.append(torch.bfloat16) + test_dtypes.append(torch.float16) + +DTYPE_CASES = [] +for dtype in test_dtypes: + for case in CASES: + DTYPE_CASES.append(case + [dtype]) + class TestAutoencoderKlMaisi(unittest.TestCase): - @parameterized.expand(CASES) - def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): - net = AutoencoderKlMaisi(**input_param).to(device) + @parameterized.expand(DTYPE_CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype): + net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) + result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype)) self.assertEqual(result[0].shape, expected_shape) self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) - @parameterized.expand(CASES) + @parameterized.expand(DTYPE_CASES) @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_with_convtranspose_and_checkpointing( - self, input_param, input_shape, expected_shape, expected_latent_shape + self, input_param, input_shape, expected_shape, expected_latent_shape, dtype ): input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) - net = AutoencoderKlMaisi(**input_param).to(device) + net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) + result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype)) self.assertEqual(result[0].shape, expected_shape) self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape)