
model-zoo: error when running inference with the Breast density classification bundle #606

Open
shihzenq opened this issue Jul 25, 2024 · 5 comments


@shihzenq

shihzenq commented Jul 25, 2024

Hello, my question is as follows.
I have downloaded model.pth from the bundle's download address.

Code:

```python
import glob
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.models import inception_v3
from monai.bundle import download, load

data_dir = 'breast_density_classification/sample_data'

test_images = sorted(glob.glob(os.path.join(data_dir, "A", "*.jpg")))

preprocess = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_path = "breast_density_classification/models/model.pth"

model = inception_v3(pretrained=False, aux_logits=False, num_classes=4).to(device)

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

for img_path in test_images:
    img = Image.open(img_path).convert('RGB')
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()

    # Show each image with its predicted density class
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.title(f'Predicted Class: {pred_class}')
    plt.axis('off')
    plt.show()
```

Error:

```
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Inception3:
Missing key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight",
```

This error indicates that the loaded state_dict does not match the InceptionV3 model I defined. Did I do something wrong?
This is an urgent problem. Thank you.
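Independent of the root cause, it can help to list exactly which keys differ between the checkpoint and the model. The sketch below does that with plain Python sets and also checks whether the checkpoint keys are simply the model keys with one extra leading prefix (a common pattern when weights are saved from a wrapper module, e.g. `module.` from `DataParallel`). The helper names `diff_state_dict_keys` and `guess_prefix` are hypothetical and not part of MONAI or torchvision.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected): keys the model expects but the checkpoint lacks, and vice versa."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def guess_prefix(missing, unexpected):
    """If every unexpected key is a missing key plus one shared leading prefix, return that prefix."""
    if not missing or not unexpected:
        return None
    missing_set = set(missing)
    first = unexpected[0]
    for key in missing:
        if first != key and first.endswith(key):
            prefix = first[: len(first) - len(key)]
            # Accept the prefix only if stripping it explains every unexpected key.
            if all(k.startswith(prefix) and k[len(prefix):] in missing_set for k in unexpected):
                return prefix
    return None
```

To use it, pass `model.state_dict().keys()` and `torch.load(model_path, map_location="cpu").keys()`. PyTorch's own `model.load_state_dict(state_dict, strict=False)` also returns `missing_keys` and `unexpected_keys`, but `strict=False` silently leaves any missing weights at their initial values, so it is better suited to diagnosis than to inference.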

@yiheng-wang-nv
Collaborator

It may be due to the torchvision version you imported not matching the bundle's version.
Did you use 0.14.1?
https://github.com/Project-MONAI/model-zoo/blob/dev/models/breast_density_classification/configs/metadata.json#L17
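A quick way to confirm such a mismatch is to compare the torchvision version pinned in the bundle's `configs/metadata.json` with the installed one. This is a minimal sketch: rather than assuming which field the pin sits under, it searches the JSON recursively for a `torchvision` key, and the path in the `__main__` block assumes the same bundle layout as the code in the issue. The helper names are hypothetical.

```python
import json
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path


def find_key(obj, key):
    """Depth-first search of nested dicts/lists; return the first value stored under `key`, else None."""
    if isinstance(obj, dict):
        if key in obj:
            return obj[key]
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        return None
    for child in children:
        found = find_key(child, key)
        if found is not None:
            return found
    return None


def installed_version(package):
    """Installed distribution version, or None if the package is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def versions_match(expected, installed):
    """Compare versions, ignoring local build tags such as '+cu117'."""
    if expected is None or installed is None:
        return False
    return installed.split("+")[0] == expected


def check_bundle_package(metadata_path, package="torchvision"):
    """Return (expected, installed, match) for `package` against a bundle's metadata.json."""
    metadata = json.loads(Path(metadata_path).read_text())
    expected = find_key(metadata, package)
    installed = installed_version(package)
    return expected, installed, versions_match(expected, installed)


if __name__ == "__main__":
    meta = Path("breast_density_classification/configs/metadata.json")  # assumed bundle layout
    if meta.exists():
        expected, installed, ok = check_bundle_package(meta)
        print(f"bundle expects torchvision {expected}, installed {installed}, match={ok}")
    else:
        print(f"{meta} not found")
```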

@shihzenq
Author

Thank you, it was the version. Inference works now.
I would like to know how to train the model. I found no training code at https://arxiv.org/abs/2202.08238, and I don't know where to get the data.

@yiheng-wang-nv
Copy link
Collaborator

Hi @shihzenq, for algorithm-side questions, I would suggest asking the bundle author:
https://github.com/Project-MONAI/model-zoo/tree/dev/models/breast_density_classification#contributors

Hi @vikashg, could you provide some information here? Thanks!

@shihzenq
Author

@yiheng-wang-nv Thank you!

@vikashg
Contributor

vikashg commented Jul 29, 2024

Hi @shihzenq,
Thanks @yiheng-wang-nv. Sorry for the late response. I hope the torchvision problem is sorted at this point.
We at Mayo Clinic did not share the training code through the MONAI Model Zoo, but you can write your own training loop.
It is basically fine-tuning the Inception-v3 model, nothing particularly fancy. Also, the dataset was not made public because it belongs to Mayo Clinic. You can try any publicly available dataset (for example, the challenge data from a recent federated learning competition held by NVIDIA). I think the data is available on Kaggle.
Hope this helps.
Vikash


3 participants