diff --git a/auto3dseg/algorithm_templates/dints/scripts/train.py b/auto3dseg/algorithm_templates/dints/scripts/train.py index aa322069..06d9392d 100644 --- a/auto3dseg/algorithm_templates/dints/scripts/train.py +++ b/auto3dseg/algorithm_templates/dints/scripts/train.py @@ -27,15 +27,11 @@ import mlflow import mlflow.pytorch +import monai import numpy as np import torch import torch.distributed as dist import yaml -from torch.nn.parallel import DistributedDataParallel -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -import monai from monai import transforms from monai.apps.auto3dseg.auto_runner import logger from monai.apps.utils import DEFAULT_FMT @@ -45,8 +41,11 @@ from monai.data import DataLoader, partition_dataset from monai.inferers import sliding_window_inference from monai.metrics import compute_dice -from monai.networks.utils import pytorch_after from monai.utils import RankFilter, set_determinism +from monai.utils.module import pytorch_after +from torch.nn.parallel import DistributedDataParallel +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm try: from apex.contrib.clip_grad import clip_grad_norm_