Update torch.load usage to eliminate complaint messages
Signed-off-by: Eric Kerfoot <[email protected]>
ericspod committed Feb 25, 2025
1 parent a7b615e commit 4685eca
Showing 24 changed files with 47 additions and 41 deletions.
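
Every change follows the same pattern: torch.load calls that read checkpoints or state dicts now pass weights_only=True, which restricts unpickling to tensor and container data and silences the warning that recent PyTorch releases emit when the argument is left unset. A minimal sketch of the before/after pattern, using an illustrative model and file name rather than anything from the repository:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), "model.pt")

    # Before: relies on the historical default (weights_only=False) and triggers
    # a FutureWarning on recent PyTorch versions.
    state = torch.load("model.pt", map_location="cpu")

    # After: the pattern applied throughout this commit; only tensor data is
    # unpickled, so the warning disappears and arbitrary-code pickles are refused.
    state = torch.load("model.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(state)

The one exception is the persistent dataset cache (monai/data/dataset.py below), which keeps weights_only=False for its own cache files.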
2 changes: 1 addition & 1 deletion monai/apps/detection/networks/retinanet_detector.py
@@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor):
nesterov=True,
)
torch.save(detector.network.state_dict(), 'model.pt') # save model
- detector.network.load_state_dict(torch.load('model.pt')) # load model
+ detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
"""

def __init__(
2 changes: 1 addition & 1 deletion monai/apps/mmars/mmars.py
@@ -241,7 +241,7 @@ def load_from_mmar(
return torch.jit.load(_model_file, map_location=map_location)

# loading with `torch.load`
- model_dict = torch.load(_model_file, map_location=map_location)
+ model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
if weights_only:
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly

4 changes: 2 additions & 2 deletions monai/bundle/scripts.py
@@ -737,7 +737,7 @@ def load(
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
- model_dict = torch.load(full_path, map_location=torch.device(device))
+ model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)

if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1306,7 +1306,7 @@ def _export(
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
else:
- ckpt = torch.load(ckpt_file)
+ ckpt = torch.load(ckpt_file, weights_only=True)
copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])

# Use the given converter to convert a model and save with metadata, config content
10 changes: 2 additions & 8 deletions monai/data/dataset.py
@@ -372,10 +372,7 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
if "weights_only" in signature(torch.load).parameters:
return torch.load(hashfile, weights_only=False)
else:
return torch.load(hashfile)
return torch.load(hashfile, weights_only=False)
except PermissionError as e:
if sys.platform != "win32":
raise e
@@ -1674,7 +1671,4 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
if "weights_only" in signature(torch.load).parameters:
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
else:
return torch.load(self.cache_dir / meta_hash_file_name)
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
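
The dataset cache above is the one place the commit keeps weights_only=False, while dropping the old signature(torch.load) guard, presumably because every PyTorch version MONAI still supports accepts the argument. These cache files are written by the same process that later reads them and can hold arbitrary pickled objects produced by the transform chain, not just tensors. A rough illustration of that trusted-cache case, with made-up entry contents and file name:

    import numpy as np
    import torch

    # A cache entry as the dataset itself might write it: not only tensors but
    # whatever the transform pipeline produced (metadata dicts, numpy arrays, ...).
    entry = {
        "image": torch.rand(1, 16, 16),
        "image_meta_dict": {"filename_or_obj": "img.nii.gz", "spacing": np.array([1.0, 1.0])},
    }
    torch.save(entry, "cache_entry.pt")

    # The file is trusted because this process created it, so the unrestricted
    # loader is kept here, unlike the checkpoint-loading paths elsewhere in the commit.
    restored = torch.load("cache_entry.pt", weights_only=False)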
2 changes: 1 addition & 1 deletion monai/fl/client/monai_algo.py
@@ -574,7 +574,7 @@ def get_weights(self, extra=None):
model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
if not os.path.isfile(model_path):
raise ValueError(f"No best model checkpoint exists at {model_path}")
- weights = torch.load(model_path, map_location="cpu")
+ weights = torch.load(model_path, map_location="cpu", weights_only=True)
# if weights contain several state dicts, use the one defined by `save_dict_key`
if isinstance(weights, dict) and self.save_dict_key in weights:
weights = weights.get(self.save_dict_key)
2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
- checkpoint = torch.load(self.load_path, map_location=self.map_location)
+ checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)

k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
2 changes: 1 addition & 1 deletion monai/losses/perceptual.py
@@ -374,7 +374,7 @@ def __init__(
else:
network = torchvision.models.resnet50(weights=None)
if pretrained is True:
- state_dict = torch.load(pretrained_path)
+ state_dict = torch.load(pretrained_path, weights_only=True)
if pretrained_state_dict_key is not None:
state_dict = state_dict[pretrained_state_dict_key]
network.load_state_dict(state_dict)
9 changes: 5 additions & 4 deletions monai/networks/nets/hovernet.py
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
- state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
-     "desc"
- ]
+ map_location = None if torch.cuda.is_available() else torch.device("cpu")
+ state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
+
for key in list(state_dict.keys()):
new_key = None
if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
- state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
+ map_location = None if torch.cuda.is_available() else torch.device("cpu")
+ state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
if state_dict_key is not None:
state_dict = state_dict[state_dict_key]

4 changes: 2 additions & 2 deletions monai/networks/nets/resnet.py
@@ -493,7 +493,7 @@ def _resnet(
if isinstance(pretrained, str):
if Path(pretrained).exists():
logger.info(f"Loading weights from {pretrained}...")
- model_state_dict = torch.load(pretrained, map_location=device)
+ model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
else:
# Throw error
raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
raise EntryNotFoundError(
f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
) from None
- checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+ checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
else:
raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
logger.info(f"{filename} downloaded")
2 changes: 1 addition & 1 deletion monai/networks/nets/senet.py
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):

if isinstance(model_url, dict):
download_url(model_url["url"], filepath=model_url["filename"])
- state_dict = torch.load(model_url["filename"], map_location=None)
+ state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
2 changes: 1 addition & 1 deletion monai/networks/nets/swin_unetr.py
@@ -1118,7 +1118,7 @@ def filter_swinunetr(key, value):
)
ssl_weights_path = "./ssl_pretrained_weights.pth"
download_url(resource, ssl_weights_path)
- ssl_weights = torch.load(ssl_weights_path)["model"]
+ ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
3 changes: 2 additions & 1 deletion monai/networks/nets/transchex.py
@@ -68,7 +68,8 @@ def from_pretrained(
weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
if state_dict is None and not from_tf:
- state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
+ map_location = "cpu" if not torch.cuda.is_available() else None
+ state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
if from_tf:
return load_tf_weights_in_bert(model, weights_path)
old_keys = []
2 changes: 1 addition & 1 deletion monai/utils/state_cacher.py
@@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
fn = self.cached[key]["obj"] # pytype: disable=attribute-error
if not os.path.exists(fn): # pytype: disable=wrong-arg-types
raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
- data_obj = torch.load(fn, map_location=lambda storage, location: storage)
+ data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
# copy back to device if necessary
if "device" in self.cached[key]:
data_obj = data_obj.to(self.cached[key]["device"])
22 changes: 16 additions & 6 deletions tests/bundle/test_bundle_download.py
@@ -266,6 +266,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
with skip_if_downloading_fails():
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
+ bundle_root = os.path.join(tempdir, bundle_name)
# load weights
weights = load(
name=bundle_name,
@@ -278,7 +279,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
return_state_dict=True,
)
# prepare network
- with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f:
+ with open(os.path.join(bundle_root, bundle_files[2])) as f:
net_args = json.load(f)["network_def"]
model_name = net_args["_target_"]
del net_args["_target_"]
@@ -288,9 +289,13 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
model.eval()

# prepare data and test
- input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device)
+ input_tensor = torch.load(
+     os.path.join(bundle_root, bundle_files[4]), map_location=device, weights_only=True
+ )
output = model.forward(input_tensor)
- expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device)
+ expected_output = torch.load(
+     os.path.join(bundle_root, bundle_files[3]), map_location=device, weights_only=True
+ )
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

# load instantiated model directly and test, since the bundle has been downloaded,
@@ -350,7 +355,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow_type="train"
)
expected_model = workflow.network_def.to(device)
- expected_model.load_state_dict(torch.load(model_path))
+ expected_model.load_state_dict(torch.load(model_path, weights_only=True))
expected_output = expected_model(input_tensor)
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@@ -378,6 +383,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device,
with skip_if_downloading_fails():
# load ts module
with tempfile.TemporaryDirectory() as tempdir:
+ bundle_root = os.path.join(tempdir, bundle_name)
# load ts module
model_ts, metadata, extra_file_dict = load(
name=bundle_name,
@@ -393,9 +399,13 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device,
)

# prepare and test ts
- input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device)
+ input_tensor = torch.load(
+     os.path.join(bundle_root, bundle_files[1]), map_location=device, weights_only=True
+ )
output = model_ts.forward(input_tensor)
- expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device)
+ expected_output = torch.load(
+     os.path.join(bundle_root, bundle_files[0]), map_location=device, weights_only=True
+ )
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
# test metadata
self.assertTrue(metadata["pytorch_version"] == "1.7.1")
2 changes: 1 addition & 1 deletion tests/data/meta_tensor/test_meta_tensor.py
@@ -245,7 +245,7 @@ def test_pickling(self):
with tempfile.TemporaryDirectory() as tmp_dir:
fname = os.path.join(tmp_dir, "im.pt")
torch.save(m, fname)
- m2 = torch.load(fname)
+ m2 = torch.load(fname, weights_only=True)
self.check(m2, m, ids=False)

@skip_if_no_cuda
2 changes: 1 addition & 1 deletion tests/integration/test_integration_classification_2d.py
@@ -166,7 +166,7 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(test_y))).to(device)

model_filename = os.path.join(root_dir, "best_metric_model.pth")
- model.load_state_dict(torch.load(model_filename))
+ model.load_state_dict(torch.load(model_filename, weights_only=True))
y_true = []
y_pred = []
with eval_mode(model):
2 changes: 1 addition & 1 deletion tests/integration/test_integration_segmentation_3d.py
@@ -216,7 +216,7 @@ def run_inference_test(root_dir, device="cuda:0"):
).to(device)

model_filename = os.path.join(root_dir, "best_metric_model.pth")
- model.load_state_dict(torch.load(model_filename))
+ model.load_state_dict(torch.load(model_filename, weights_only=True))
with eval_mode(model):
# resampling with align_corners=True or dtype=float64 will generate
# slight different results between PyTorch 1.5 an 1.6
2 changes: 1 addition & 1 deletion tests/networks/nets/test_autoencoderkl.py
@@ -330,7 +330,7 @@ def test_compatibility_with_monai_generative(self):
weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

- net.load_old_state_dict(torch.load(weight_path), verbose=False)
+ net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/networks/nets/test_controlnet.py
@@ -208,7 +208,7 @@ def test_compatibility_with_monai_generative(self):
weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

- net.load_old_state_dict(torch.load(weight_path), verbose=False)
+ net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/networks/nets/test_diffusion_model_unet.py
@@ -578,7 +578,7 @@ def test_compatibility_with_monai_generative(self):
weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

- net.load_old_state_dict(torch.load(weight_path), verbose=False)
+ net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/networks/nets/test_network_consistency.py
@@ -55,7 +55,7 @@ def test_network_consistency(self, net_name, data_path, json_path):
print("JSON path: " + json_path)

# Load data
- loaded_data = torch.load(data_path)
+ loaded_data = torch.load(data_path, weights_only=True)

# Load json from file
json_file = open(json_path)
2 changes: 1 addition & 1 deletion tests/networks/nets/test_swin_unetr.py
@@ -128,7 +128,7 @@ def test_filter_swinunetr(self, input_param, key, value):
data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"]
)

- ssl_weight = torch.load(weight_path)["model"]
+ ssl_weight = torch.load(weight_path, weights_only=True)["model"]
net = SwinUNETR(**input_param)
dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr)
assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False)
2 changes: 1 addition & 1 deletion tests/networks/nets/test_transformer.py
@@ -101,7 +101,7 @@ def test_compatibility_with_monai_generative(self):
weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

- net.load_old_state_dict(torch.load(weight_path), verbose=False)
+ net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/networks/test_save_state.py
@@ -64,7 +64,7 @@ def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None,
if kwargs is None:
kwargs = {}
save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **kwargs)
- ckpt = dict(torch.load(path))
+ ckpt = dict(torch.load(path, weights_only=True))
for k in ckpt.keys():
self.assertIn(k, expected_keys)

