6676 port generative inferers #7379

Merged
Commits (39, changes shown from all commits)
fd29603
Adds inferers. Changes arg in the spade_diffusion_model from num_chan…
marksgraham Dec 20, 2023
495758f
Updates docs
marksgraham Dec 20, 2023
3ebaf9f
Start to address mypy issues, inc changing base class from Inferers
marksgraham Dec 21, 2023
740467b
Merge branch 'gen-ai-dev' into 6676_port_generative_inferers
marksgraham Jan 4, 2024
8fef41c
Address more mypy
marksgraham Jan 4, 2024
13fd339
Merge branch 'gen-ai-dev' into 6676_port_generative_inferers
marksgraham Jan 8, 2024
9f5a903
Inferers mypy compatible
marksgraham Jan 9, 2024
0ded5a3
DCO Remediation Commit for Mark Graham <[email protected]>
marksgraham Jan 9, 2024
25f06d6
DCO
marksgraham Jan 9, 2024
818eb68
Skip test if scipy not installed
marksgraham Jan 9, 2024
74d7663
Skip test if scipy not installed
marksgraham Jan 9, 2024
6b0b389
Try to correct non-contiguous error
marksgraham Jan 9, 2024
a1e1bda
Contigous again
marksgraham Jan 9, 2024
a15da83
Adds missing VQVAETranformerInferer tests
marksgraham Jan 10, 2024
9cb196d
Formatting
marksgraham Jan 10, 2024
ecc1d7c
Update monai/inferers/inferer.py
marksgraham Jan 10, 2024
f0f53e5
Remove unnecessary partial calls, increase test coverage
marksgraham Jan 10, 2024
5c018cf
Test if changing inferer inheritance affects contiguous error
marksgraham Jan 10, 2024
86b21e8
contig
marksgraham Jan 10, 2024
d654216
contig
marksgraham Jan 10, 2024
ccc3110
undo
marksgraham Jan 10, 2024
22ba322
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
15af706
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
2f6bda5
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
22bf240
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
095be61
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
cc28b20
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
3ba1363
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
f162dee
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
4c2085c
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
4c6d788
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
47b5958
Update monai/inferers/inferer.py
marksgraham Jan 11, 2024
97bd662
Updates to comments
marksgraham Jan 11, 2024
553c94b
Move tests
marksgraham Jan 15, 2024
38f832a
DCO Remediation Commit for Mark Graham <[email protected]>
marksgraham Jan 15, 2024
7ef3fb5
DCO
marksgraham Jan 15, 2024
ac891d8
Updates setup.cof to fix premerge
marksgraham Jan 17, 2024
9073d85
Fixes to tests for premerge
marksgraham Jan 17, 2024
1d8e7cc
Remove random test
marksgraham Jan 18, 2024
Files changed
23 changes: 23 additions & 0 deletions docs/source/inferers.rst
@@ -49,6 +49,29 @@ Inferers
    :members:
    :special-members: __call__

+`DiffusionInferer`
+~~~~~~~~~~~~~~~~~~
+.. autoclass:: DiffusionInferer
+    :members:
+    :special-members: __call__
+
+`LatentDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: LatentDiffusionInferer
+    :members:
+    :special-members: __call__
+
+`ControlNetDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ControlNetDiffusionInferer
+    :members:
+    :special-members: __call__
+
+`ControlNetLatentDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ControlNetLatentDiffusionInferer
+    :members:
+    :special-members: __call__

Splitters
---------
5 changes: 5 additions & 0 deletions monai/inferers/__init__.py
@@ -12,13 +12,18 @@
from __future__ import annotations

from .inferer import (
+    ControlNetDiffusionInferer,
+    ControlNetLatentDiffusionInferer,
+    DiffusionInferer,
    Inferer,
+    LatentDiffusionInferer,
    PatchInferer,
    SaliencyInferer,
    SimpleInferer,
    SliceInferer,
    SlidingWindowInferer,
    SlidingWindowInfererAdapt,
+    VQVAETransformerInferer,
)
from .merger import AvgMerger, Merger, ZarrAvgMerger
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
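The classes exported above follow the inferer pattern ported from MONAI Generative: construct with a scheduler, call the inferer for a single training-style noise prediction, and use sample() for full reverse-process generation. The following is a minimal sketch, assuming the ported API matches the MONAI Generative original; the network and scheduler hyperparameters (and the schedulers import path) are illustrative assumptions, not taken from this diff.

# Minimal usage sketch of the new DiffusionInferer (API assumed to match the
# MONAI Generative code this PR ports; values are illustrative).
import torch

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler  # assumed import path after the port

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=64,
)
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

images = torch.randn(2, 1, 64, 64)  # stand-in training batch
noise = torch.randn_like(images)
timesteps = torch.randint(0, 1000, (images.shape[0],)).long()

# training-style call: predict the noise mixed into `images` at `timesteps`
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

# generation: run the full reverse diffusion process from pure noise
synthetic = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)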
1,280 changes: 1,279 additions & 1 deletion monai/inferers/inferer.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions monai/networks/nets/diffusion_model_unet.py
@@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, channel, height, width, depth = x.shape

        # norm
-        x = self.norm(x)
+        x = self.norm(x.contiguous())

        if self.spatial_dims == 2:
            x = x.view(batch, channel, height * width).transpose(1, 2)
@@ -682,7 +682,7 @@ def __init__(
        )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
-        h = x
+        h = x.contiguous()
        h = self.norm1(h)
        h = self.nonlinearity(h)

@@ -1957,7 +1957,7 @@ def forward(
            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

        # 7. output block
-        output: torch.Tensor = self.out(h)
+        output: torch.Tensor = self.out(h.contiguous())

        return output
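The three .contiguous() calls above are the fix for the non-contiguous tensor errors chased across several commits in this PR ("Try to correct non-contiguous error", "contig"). Stride-manipulating ops such as transpose return views that share storage but are not laid out compactly in memory, and view() (plus some kernels) rejects such inputs; .contiguous() materialises a compact copy. A small self-contained illustration:

# Why the defensive .contiguous() calls help: transposed tensors are views
# with swapped strides, and view() requires contiguous memory.
import torch

x = torch.randn(2, 4, 8, 8)
t = x.transpose(1, 2)        # swaps strides only; no data is copied
print(t.is_contiguous())     # False

try:
    t.view(2, -1)            # raises: view() needs contiguous memory here
except RuntimeError as err:
    print("view failed:", err)

c = t.contiguous()           # materialise a compact, row-major copy
print(c.is_contiguous())     # True
print(c.view(2, -1).shape)   # torch.Size([2, 256])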
40 changes: 20 additions & 20 deletions monai/networks/nets/spade_diffusion_model_unet.py
@@ -618,7 +618,7 @@ class SPADEDiffusionModelUNet(nn.Module):
        out_channels: number of output channels.
        label_nc: number of semantic channels for SPADE normalisation.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
-        num_channels: tuple of block output channels.
+        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
@@ -641,7 +641,7 @@ def __init__(
        out_channels: int,
        label_nc: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
-        num_channels: Sequence[int] = (32, 64, 64, 64),
+        channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
@@ -667,10 +667,10 @@
            )

        # All number of channels should be multiple of num_groups
-        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
            raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

-        if len(num_channels) != len(attention_levels):
+        if len(channels) != len(attention_levels):
            raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
@@ -683,9 +683,9 @@
            )

        if isinstance(num_res_blocks, int):
-            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
+            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))

-        if len(num_res_blocks) != len(num_channels):
+        if len(num_res_blocks) != len(channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`num_channels`."
@@ -700,7 +700,7 @@
            )

        self.in_channels = in_channels
-        self.block_out_channels = num_channels
+        self.block_out_channels = channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
@@ -712,17 +712,17 @@
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
-            out_channels=num_channels[0],
+            out_channels=channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # time
-        time_embed_dim = num_channels[0] * 4
+        time_embed_dim = channels[0] * 4
        self.time_embed = nn.Sequential(
-            nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
        )

        # class embedding
@@ -732,11 +732,11 @@

        # down
        self.down_blocks = nn.ModuleList([])
-        output_channel = num_channels[0]
-        for i in range(len(num_channels)):
+        output_channel = channels[0]
+        for i in range(len(channels)):
            input_channel = output_channel
-            output_channel = num_channels[i]
-            is_final_block = i == len(num_channels) - 1
+            output_channel = channels[i]
+            is_final_block = i == len(channels) - 1

            down_block = get_down_block(
                spatial_dims=spatial_dims,
@@ -762,7 +762,7 @@
        # mid
        self.middle_block = get_mid_block(
            spatial_dims=spatial_dims,
-            in_channels=num_channels[-1],
+            in_channels=channels[-1],
            temb_channels=time_embed_dim,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
@@ -776,17 +776,17 @@

        # up
        self.up_blocks = nn.ModuleList([])
-        reversed_block_out_channels = list(reversed(num_channels))
+        reversed_block_out_channels = list(reversed(channels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_head_channels = list(reversed(num_head_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]
+            input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]

-            is_final_block = i == len(num_channels) - 1
+            is_final_block = i == len(channels) - 1

            up_block = get_spade_up_block(
                spatial_dims=spatial_dims,
@@ -814,12 +814,12 @@

        # out
        self.out = nn.Sequential(
-            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
            nn.SiLU(),
            zero_module(
                Convolution(
                    spatial_dims=spatial_dims,
-                    in_channels=num_channels[0],
+                    in_channels=channels[0],
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
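For callers, the visible effect of this file's changes is the constructor keyword rename from num_channels to channels; the validation rules are unchanged. A construction sketch with illustrative values:

# Construction sketch after the rename; all hyperparameter values are illustrative.
from monai.networks.nets import SPADEDiffusionModelUNet

net = SPADEDiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    label_nc=3,                                   # semantic channels for SPADE normalisation
    channels=(32, 64, 64, 64),                    # formerly `num_channels`
    attention_levels=(False, False, True, True),  # must be the same length as `channels`
    norm_num_groups=32,                           # each entry of `channels` must be a multiple of this
)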
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -126,6 +126,7 @@
    version_leq,
)
from .nvtx import Range
+from .ordering import Ordering
from .profiling import (
    PerfContext,
    ProfileHandler,
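Ordering is re-exported here because the new VQVAETransformerInferer needs a fixed mapping between a 2D/3D grid of latent tokens and the 1D sequence a transformer consumes. A sketch of its use, with the constructor arguments and method names assumed from the MONAI Generative original rather than taken from this diff:

# Sketch only: argument names and methods assumed from the MONAI Generative original.
from monai.utils import Ordering

# Flatten a (1, 4, 4) latent token grid in raster-scan order.
ordering = Ordering(ordering_type="raster_scan", spatial_dims=2, dimensions=(1, 4, 4))
seq_idx = ordering.get_sequence_ordering()         # grid -> sequence permutation
rev_idx = ordering.get_revert_sequence_ordering()  # inverse, to restore the grid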
14 changes: 8 additions & 6 deletions setup.cfg
@@ -52,20 +52,20 @@ all =
    scipy>=1.7.1
    pillow
    tensorboard
-    gdown>=4.4.0
+    gdown==4.6.3
    pytorch-ignite==0.4.11
    torchvision
    itk>=5.2
    tqdm>=4.47.0
    lmdb
    psutil
    cucim>=23.2.0
-    openslide-python==1.1.2
+    openslide-python
    tifffile
    imagecodecs
    pandas
    einops
-    transformers<4.22
+    transformers<4.22; python_version <= '3.10'
    mlflow>=1.28.0
    clearml>=1.10.0rc0
    matplotlib
@@ -97,7 +97,7 @@ pillow =
tensorboard =
    tensorboard
gdown =
-    gdown>=4.4.0
+    gdown==4.6.3
ignite =
    pytorch-ignite==0.4.11
torchvision =
@@ -113,7 +113,7 @@ psutil =
cucim =
    cucim>=23.2.0
openslide =
-    openslide-python==1.1.2
+    openslide-python
tifffile =
    tifffile
imagecodecs =
@@ -123,7 +123,7 @@ pandas =
einops =
    einops
transformers =
-    transformers<4.22
+    transformers<4.22; python_version <= '3.10'
mlflow =
    mlflow
matplotlib =
@@ -173,6 +173,7 @@ max_line_length = 120
# B028 https://github.com/Project-MONAI/MONAI/issues/5855
# B907 https://github.com/Project-MONAI/MONAI/issues/5868
# B908 https://github.com/Project-MONAI/MONAI/issues/6503
+# B036 https://github.com/Project-MONAI/MONAI/issues/7396
ignore =
    E203
    E501
@@ -186,6 +187,7 @@
    B028
    B907
    B908
+    B036
per_file_ignores = __init__.py: F401, __main__.py: F401
exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py

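The `; python_version <= '3.10'` suffix added to the transformers pins is a PEP 508 environment marker: pip installs the requirement only when the marker is true for the running interpreter, so the old `<4.22` pin no longer blocks installs on Python 3.11+. The marker can be checked with the `packaging` library:

# Evaluating the new environment marker with the `packaging` library.
from packaging.requirements import Requirement

req = Requirement("transformers<4.22; python_version <= '3.10'")
print(req.marker.evaluate({"python_version": "3.10"}))  # True  -> installed
print(req.marker.evaluate({"python_version": "3.11"}))  # False -> skipped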