Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix visualization when greyscale image as input to ImageFeature class #1044

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ commands:
sudo apt-key add /var/cuda-repo-ubuntu2004-11-4-local/7fa2af80.pub
sudo apt-get update
sudo dpkg --configure -a
sudo apt-get --yes --force-yes install cuda
sudo apt-get --yes --allow-downgrades --allow-remove-essential --allow-change-held-packages install cuda

jobs:

Expand Down
20 changes: 14 additions & 6 deletions captum/insights/attr_vis/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from io import BytesIO
from typing import Callable, List, Optional, Union

import numpy as np
from captum._utils.common import safe_div
from captum.attr._utils import visualization as viz
from captum.insights.attr_vis._utils.transforms import format_transforms
Expand Down Expand Up @@ -117,12 +118,19 @@ def visualization_type() -> str:
def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
if self.visualization_transform:
data = self.visualization_transform(data)

data_t, attribution_t = [
t.detach().squeeze().permute((1, 2, 0)).cpu().numpy()
for t in (data, attribution)
]

# [N, C, H, W] if C==3, its expected to be in RGB format
if data.shape[:-2][-1] == 3:
data_t, attribution_t = [
t.detach().squeeze().permute((1, 2, 0)).cpu().numpy()
for t in (data, attribution)
]
# [N, C, H, W] if C==1, its assumed to be a greyscale image
if data.shape[:-2][-1] == 1:
data_t, attribution_t = [
t.detach().squeeze().cpu().numpy() for t in (data, attribution)
]
data_t = np.expand_dims(data_t, axis=-1)
attribution_t = np.expand_dims(attribution_t, axis=-1)
orig_fig, _ = viz.visualize_image_attr(
attribution_t, data_t, method="original_image", use_pyplot=False
)
Expand Down