diff --git a/demo/predictor.py b/demo/predictor.py
index fa663c7e4..95b0a5c3b 100644
--- a/demo/predictor.py
+++ b/demo/predictor.py
@@ -10,6 +10,7 @@
 from maskrcnn_benchmark import layers as L
 from maskrcnn_benchmark.utils import cv2_util
 
+
 class Resize(object):
     def __init__(self, min_size, max_size):
         self.min_size = min_size
@@ -42,6 +43,8 @@ def __call__(self, image):
         size = self.get_size(image.size)
         image = F.resize(image, size)
         return image
+
+
 class COCODemo(object):
     # COCO categories for pretty print
     CATEGORIES = [
@@ -197,9 +200,8 @@ def run_on_opencv_image(self, image):
             image (np.ndarray): an image as returned by OpenCV
 
         Returns:
-            prediction (BoxList): the detected objects. Additional information
-            of the detection properties can be found in the fields of
-            the BoxList via `prediction.fields()`
+            result (np.ndarray): an image with detected results
+                (boxes, masks, keypoints, etc.) shown on top of it.
         """
         predictions = self.compute_prediction(image)
         top_predictions = self.select_top_predictions(predictions)
@@ -327,9 +329,7 @@ def overlay_mask(self, image, predictions):
             )
             image = cv2.drawContours(image, contours, -1, color, 3)
 
-        composite = image
-
-        return composite
+        return image
 
     def overlay_keypoints(self, image, predictions):
         keypoints = predictions.get_field("keypoints")
@@ -342,7 +342,7 @@ def overlay_keypoints(self, image, predictions):
 
     def create_mask_montage(self, image, predictions):
         """
-        Create a montage showing the probability heatmaps for each one one of the
+        Create a montage showing the probability heatmaps for each one of the
         detected objects
 
         Arguments:
diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py
index 65fbdaddd..4dcca016a 100644
--- a/maskrcnn_benchmark/config/defaults.py
+++ b/maskrcnn_benchmark/config/defaults.py
@@ -186,8 +186,8 @@ _C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10., 10., 5., 5.)
 
 # RoI minibatch size *per image* (number of regions of interest [ROIs])
 # Total number of RoIs per training minibatch =
-# TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH
-# E.g., a common configuration is: 512 * 2 * 8 = 8192
+# MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
+# E.g., a common configuration is: 512 * 16 = 8192
 _C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
 # Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
 _C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py
index 26239155d..afd4fe4fe 100644
--- a/maskrcnn_benchmark/data/build.py
+++ b/maskrcnn_benchmark/data/build.py
@@ -123,7 +123,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
         ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format(
             images_per_batch, num_gpus)
         images_per_gpu = images_per_batch // num_gpus
-        shuffle = False if not is_distributed else True
+        shuffle = is_distributed
         num_iters = None
         start_iter = 0
 
diff --git a/maskrcnn_benchmark/data/datasets/coco.py b/maskrcnn_benchmark/data/datasets/coco.py
index cc10f29d1..e189b25b2 100644
--- a/maskrcnn_benchmark/data/datasets/coco.py
+++ b/maskrcnn_benchmark/data/datasets/coco.py
@@ -1,7 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-import torchvision
+from torchvision.datasets.coco import CocoDetection
 
 from maskrcnn_benchmark.structures.bounding_box import BoxList
 from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
 from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
@@ -25,18 +25,18 @@ def has_valid_annotation(anno):
     # if all boxes have close to zero area, there is no annotation
     if _has_only_empty_bbox(anno):
         return False
-    # keypoints task have a slight different critera for considering
+    # the keypoints task has a slightly different criterion for considering
     # if an annotation is valid
     if "keypoints" not in anno[0]:
         return True
-    # for keypoint detection tasks, only consider valid images those
+    # for the keypoint detection task, only consider as valid images those
     # containing at least min_keypoints_per_image
     if _count_visible_keypoints(anno) >= min_keypoints_per_image:
         return True
     return False
 
 
-class COCODataset(torchvision.datasets.coco.CocoDetection):
+class COCODataset(CocoDetection):
     def __init__(
         self, ann_file, root, remove_images_without_annotations, transforms=None
     ):
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py
index a8fdc280e..f97cf71a5 100644
--- a/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py
+++ b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py
@@ -4,6 +4,7 @@
 import torch
 from collections import OrderedDict
 from tqdm import tqdm
+from pycocotools.cocoeval import COCOeval
 
 from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
 from maskrcnn_benchmark.structures.bounding_box import BoxList
@@ -45,9 +46,9 @@ def do_coco_evaluation(
     if "segm" in iou_types:
         logger.info("Preparing segm results")
         coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset)
-    if 'keypoints' in iou_types:
-        logger.info('Preparing keypoints results')
-        coco_results['keypoints'] = prepare_for_coco_keypoint(predictions, dataset)
+    if "keypoints" in iou_types:
+        logger.info("Preparing keypoints results")
+        coco_results["keypoints"] = prepare_for_coco_keypoint(predictions, dataset)
 
     results = COCOResults(*iou_types)
     logger.info("Evaluating predictions")
@@ -68,7 +69,6 @@ def do_coco_evaluation(
 
 
 def prepare_for_coco_detection(predictions, dataset):
-    # assert isinstance(dataset, COCODataset)
     coco_results = []
     for image_id, prediction in enumerate(predictions):
         original_id = dataset.id_to_img_map[image_id]
@@ -106,7 +106,6 @@ def prepare_for_coco_segmentation(predictions, dataset):
     import numpy as np
 
     masker = Masker(threshold=0.5, padding=1)
-    # assert isinstance(dataset, COCODataset)
     coco_results = []
     for image_id, prediction in tqdm(enumerate(predictions)):
         original_id = dataset.id_to_img_map[image_id]
@@ -118,20 +117,15 @@ def prepare_for_coco_segmentation(predictions, dataset):
         image_height = img_info["height"]
         prediction = prediction.resize((image_width, image_height))
         masks = prediction.get_field("mask")
-        # t = time.time()
+        # Masker is necessary only if masks haven't already been resized.
         if list(masks.shape[-2:]) != [image_height, image_width]:
             masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
             masks = masks[0]
-        # logger.info('Time mask: {}'.format(time.time() - t))
-        # prediction = prediction.convert('xywh')
 
-        # boxes = prediction.bbox.tolist()
         scores = prediction.get_field("scores").tolist()
         labels = prediction.get_field("labels").tolist()
 
-        # rles = prediction.get_field('mask')
-
         rles = [
             mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
             for mask in masks
@@ -156,33 +150,36 @@ def prepare_for_coco_segmentation(predictions, dataset):
 
 
 def prepare_for_coco_keypoint(predictions, dataset):
-    # assert isinstance(dataset, COCODataset)
     coco_results = []
     for image_id, prediction in enumerate(predictions):
         original_id = dataset.id_to_img_map[image_id]
         if len(prediction.bbox) == 0:
             continue
 
-        # TODO replace with get_img_info?
-        image_width = dataset.coco.imgs[original_id]['width']
-        image_height = dataset.coco.imgs[original_id]['height']
+        img_info = dataset.get_img_info(image_id)
+        image_width = img_info["width"]
+        image_height = img_info["height"]
         prediction = prediction.resize((image_width, image_height))
-        prediction = prediction.convert('xywh')
 
-        boxes = prediction.bbox.tolist()
-        scores = prediction.get_field('scores').tolist()
-        labels = prediction.get_field('labels').tolist()
-        keypoints = prediction.get_field('keypoints')
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+        keypoints = prediction.get_field("keypoints")
         keypoints = keypoints.resize((image_width, image_height))
         keypoints = keypoints.keypoints.view(keypoints.keypoints.shape[0], -1).tolist()
 
         mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
 
-        coco_results.extend([{
-            'image_id': original_id,
-            'category_id': mapped_labels[k],
-            'keypoints': keypoint,
-            'score': scores[k]} for k, keypoint in enumerate(keypoints)])
+        coco_results.extend(
+            [
+                {
+                    "image_id": original_id,
+                    "category_id": mapped_labels[k],
+                    "keypoints": keypoint,
+                    "score": scores[k],
+                }
+                for k, keypoint in enumerate(keypoints)
+            ]
+        )
 
     return coco_results
 
 
 # inspired from Detectron
@@ -311,11 +308,8 @@ def evaluate_predictions_on_coco(
         json.dump(coco_results, f)
 
     from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
 
     coco_dt = coco_gt.loadRes(str(json_result_file)) if coco_results else COCO()
-
-    # coco_dt = coco_gt.loadRes(coco_results)
     coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
     coco_eval.evaluate()
     coco_eval.accumulate()
@@ -353,7 +347,6 @@ def __init__(self, *iou_types):
     def update(self, coco_eval):
         if coco_eval is None:
             return
-        from pycocotools.cocoeval import COCOeval
 
         assert isinstance(coco_eval, COCOeval)
         s = coco_eval.stats
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
index 482081b8d..8d9f69912 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
@@ -1,6 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch import nn
 
 from .roi_box_feature_extractors import make_roi_box_feature_extractor
 from .roi_box_predictors import make_roi_box_predictor
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
index cc2f4fa85..91b803dd9 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
@@ -48,7 +48,7 @@ def forward(self, x, boxes):
             x (tuple[tensor, tensor]): x contains the class logits
                 and the box_regression from the model.
             boxes (list[BoxList]): bounding boxes that are used as
-                reference, one for ech image
+                reference, one for each image
 
         Returns:
             results (list[BoxList]): one BoxList for each image, containing
@@ -60,7 +60,7 @@ def forward(self, x, boxes):
         # TODO think about a representation of batch of boxes
         image_shapes = [box.size for box in boxes]
         boxes_per_image = [len(box) for box in boxes]
-        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
+        concat_boxes = torch.cat([box.bbox for box in boxes], dim=0)
 
         if self.cls_agnostic_bbox_reg:
             box_regression = box_regression[:, -4:]
@@ -150,8 +150,6 @@ def filter_results(self, boxlist, num_classes):
 
 
 def make_roi_box_post_processor(cfg):
-    use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN
-
     bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
     box_coder = BoxCoder(weights=bbox_reg_weights)
 
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
index e47714746..eeda59dce 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
@@ -1,5 +1,4 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-import torch
 from torch import nn
 from torch.nn import functional as F
 
@@ -74,10 +73,8 @@ def __init__(self, cfg, in_channels):
     def forward(self, x, proposals):
         x = self.pooler(x, proposals)
         x = x.view(x.size(0), -1)
-
         x = F.relu(self.fc6(x))
         x = F.relu(self.fc7(x))
-
         return x
 
 
@@ -106,7 +103,7 @@ def __init__(self, cfg, in_channels):
         dilation = cfg.MODEL.ROI_BOX_HEAD.DILATION
 
         xconvs = []
-        for ix in range(num_stacked_convs):
+        for _ in range(num_stacked_convs):
            xconvs.append(
                 nn.Conv2d(
                     in_channels,
@@ -127,9 +124,9 @@ def __init__(self, cfg, in_channels):
         for modules in [self.xconvs,]:
             for l in modules.modules():
                 if isinstance(l, nn.Conv2d):
-                    torch.nn.init.normal_(l.weight, std=0.01)
+                    nn.init.normal_(l.weight, std=0.01)
                     if not use_gn:
-                        torch.nn.init.constant_(l.bias, 0)
+                        nn.init.constant_(l.bias, 0)
 
         input_size = conv_head_dim * resolution ** 2
         representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
index 66ee4ace5..3b9b1e955 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
@@ -9,19 +9,16 @@ def __init__(self, config, in_channels):
         super(FastRCNNPredictor, self).__init__()
         assert in_channels is not None
 
-        num_inputs = in_channels
-
         num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
         self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.cls_score = nn.Linear(num_inputs, num_classes)
+        self.cls_score = nn.Linear(in_channels, num_classes)
         num_bbox_reg_classes = 2 if config.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes
-        self.bbox_pred = nn.Linear(num_inputs, num_bbox_reg_classes * 4)
+        self.bbox_pred = nn.Linear(in_channels, num_bbox_reg_classes * 4)
 
         nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)
-        nn.init.constant_(self.cls_score.bias, 0)
-
         nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001)
-        nn.init.constant_(self.bbox_pred.bias, 0)
+        for l in [self.cls_score, self.bbox_pred]:
+            nn.init.constant_(l.bias, 0)
 
     def forward(self, x):
         x = self.avgpool(x)
@@ -36,11 +33,10 @@ class FPNPredictor(nn.Module):
     def __init__(self, cfg, in_channels):
         super(FPNPredictor, self).__init__()
         num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
-        representation_size = in_channels
 
-        self.cls_score = nn.Linear(representation_size, num_classes)
+        self.cls_score = nn.Linear(in_channels, num_classes)
         num_bbox_reg_classes = 2 if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes
-        self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
+        self.bbox_pred = nn.Linear(in_channels, num_bbox_reg_classes * 4)
 
         nn.init.normal_(self.cls_score.weight, std=0.01)
         nn.init.normal_(self.bbox_pred.weight, std=0.001)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py
index 5a842cad3..078fbfc2a 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py
@@ -9,7 +9,6 @@ class ROIKeypointHead(torch.nn.Module):
     def __init__(self, cfg, in_channels):
         super(ROIKeypointHead, self).__init__()
-        self.cfg = cfg.clone()
         self.feature_extractor = make_roi_keypoint_feature_extractor(cfg, in_channels)
         self.predictor = make_roi_keypoint_predictor(
             cfg, self.feature_extractor.out_channels)
@@ -27,10 +26,11 @@ def forward(self, features, proposals, targets=None):
             x (Tensor): the result of the feature extractor
             proposals (list[BoxList]): during training, the original proposals
                 are returned. During testing, the predicted boxlists are returned
-                with the `mask` field set
+                with the `keypoints` field set
             losses (dict[Tensor]): During training, returns the losses for the
                 head. During testing, returns an empty dict.
         """
+
         if self.training:
             with torch.no_grad():
                 proposals = self.loss_evaluator.subsample(proposals, targets)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
index 7193efc25..e820509b1 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
@@ -1,6 +1,7 @@
 from torch import nn
 
-from maskrcnn_benchmark import layers
+from maskrcnn_benchmark.layers import ConvTranspose2d
+from maskrcnn_benchmark.layers import interpolate
 from maskrcnn_benchmark.modeling import registry
 
 
@@ -8,11 +9,10 @@ class KeypointRCNNPredictor(nn.Module):
     def __init__(self, cfg, in_channels):
         super(KeypointRCNNPredictor, self).__init__()
-        input_features = in_channels
         num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES
         deconv_kernel = 4
-        self.kps_score_lowres = layers.ConvTranspose2d(
-            input_features,
+        self.kps_score_lowres = ConvTranspose2d(
+            in_channels,
             num_keypoints,
             deconv_kernel,
             stride=2,
@@ -27,7 +27,7 @@ def __init__(self, cfg, in_channels):
 
     def forward(self, x):
         x = self.kps_score_lowres(x)
-        x = layers.interpolate(
+        x = interpolate(
             x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
         )
         return x
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
index bd831c085..cb6b0099f 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -3,7 +3,6 @@
 import torch
 from torch import nn
 
 from maskrcnn_benchmark.layers.misc import interpolate
-
 from maskrcnn_benchmark.structures.bounding_box import BoxList
 
 
@@ -29,7 +28,7 @@ def forward(self, x, boxes):
         Arguments:
             x (Tensor): the mask logits
             boxes (list[BoxList]): bounding boxes that are used as
-                reference, one for ech image
+                reference, one for each image
 
         Returns:
             results (list[BoxList]): one BoxList for each image, containing
@@ -70,7 +69,6 @@ class MaskPostProcessorCOCOFormat(MaskPostProcessor):
 
     def forward(self, x, boxes):
         import pycocotools.mask as mask_util
-        import numpy as np
 
         results = super(MaskPostProcessorCOCOFormat, self).forward(x, boxes)
         for result in results:
@@ -87,7 +85,7 @@ def forward(self, x, boxes):
 
 # the next two functions should be merged inside Masker
 # but are kept here for the moment while we need them
-# temporarily gor paste_mask_in_image
+# temporarily for paste_mask_in_image
 def expand_boxes(boxes, scale):
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
     h_half = (boxes[:, 3] - boxes[:, 1]) * .5
@@ -127,10 +125,8 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
     box = box.to(dtype=torch.int32)
 
     TO_REMOVE = 1
-    w = int(box[2] - box[0] + TO_REMOVE)
-    h = int(box[3] - box[1] + TO_REMOVE)
-    w = max(w, 1)
-    h = max(h, 1)
+    w = max(int(box[2] - box[0] + TO_REMOVE), 1)
+    h = max(int(box[3] - box[1] + TO_REMOVE), 1)
 
     # Set shape to [batchxCxHxW]
     mask = mask.expand((1, 1, -1, -1))
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
index d4c5e3621..21847da1b 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
@@ -19,6 +19,7 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
     Arguments:
         segmentation_masks: an instance of SegmentationMask
         proposals: an instance of BoxList
+        discretization_size: spatial resolution of masks
     """
     masks = []
     M = discretization_size
@@ -33,9 +34,9 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
     for segmentation_mask, proposal in zip(segmentation_masks, proposals):
         # crop the masks, resize them to the desired resolution and
         # then convert them to the tensor representation.
-        cropped_mask = segmentation_mask.crop(proposal)
-        scaled_mask = cropped_mask.resize((M, M))
-        mask = scaled_mask.get_mask_tensor()
+        mask = segmentation_mask.crop(proposal)
+        mask = mask.resize((M, M))
+        mask = mask.get_mask_tensor()
         masks.append(mask)
     if len(masks) == 0:
         return torch.empty(0, dtype=torch.float32, device=device)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
index a9ce245b6..f4deaa0f7 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
@@ -1,6 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch import nn
 
 from maskrcnn_benchmark.structures.bounding_box import BoxList
 
@@ -21,9 +20,9 @@ def keep_only_positive_boxes(boxes):
     assert isinstance(boxes, (list, tuple))
     assert isinstance(boxes[0], BoxList)
     assert boxes[0].has_field("labels")
+
     positive_boxes = []
     positive_inds = []
-    num_boxes = 0
     for boxes_per_image in boxes:
         labels = boxes_per_image.get_field("labels")
         inds_mask = labels > 0
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
index 117edc4cc..bd57e6ede 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
@@ -16,16 +16,10 @@
 @registry.ROI_MASK_FEATURE_EXTRACTORS.register("MaskRCNNFPNFeatureExtractor")
 class MaskRCNNFPNFeatureExtractor(nn.Module):
     """
-    Heads for FPN for classification
+    Heads for FPN for segmentation
     """
 
     def __init__(self, cfg, in_channels):
-        """
-        Arguments:
-            num_classes (int): number of output classes
-            input_size (int): number of channels of the input once it's flattened
-            representation_size (int): size of the intermediate representation
-        """
         super(MaskRCNNFPNFeatureExtractor, self).__init__()
 
         resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
@@ -36,14 +30,13 @@ def __init__(self, cfg, in_channels):
             scales=scales,
             sampling_ratio=sampling_ratio,
         )
-        input_size = in_channels
         self.pooler = pooler
 
         use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
         layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
         dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION
 
-        next_feature = input_size
+        next_feature = in_channels
         self.blocks = []
         for layer_idx, layer_features in enumerate(layers, 1):
             layer_name = "mask_fcn{}".format(layer_idx)
@@ -58,7 +51,6 @@ def __init__(self, cfg, in_channels):
 
     def forward(self, x, proposals):
         x = self.pooler(x, proposals)
-
         for layer_name in self.blocks:
             x = F.relu(getattr(self, layer_name)(x))
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
index c954e332e..433f81470 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
@@ -11,11 +11,10 @@ class MaskRCNNC4Predictor(nn.Module):
     def __init__(self, cfg, in_channels):
         super(MaskRCNNC4Predictor, self).__init__()
+
         num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
         dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]
-        num_inputs = in_channels
-
-        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
+        self.conv5_mask = ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)
         self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)
 
         for name, param in self.named_parameters():
@@ -35,10 +34,9 @@ def forward(self, x):
 
 class MaskRCNNConv1x1Predictor(nn.Module):
     def __init__(self, cfg, in_channels):
         super(MaskRCNNConv1x1Predictor, self).__init__()
-        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
-        num_inputs = in_channels
 
-        self.mask_fcn_logits = Conv2d(num_inputs, num_classes, 1, 1, 0)
+        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
+        self.mask_fcn_logits = Conv2d(in_channels, num_classes, 1, 1, 0)
 
         for name, param in self.named_parameters():
             if "bias" in name:
diff --git a/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py b/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py
index 99ed7b981..3a575bb9f 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py
@@ -16,8 +16,10 @@ def __init__(self, cfg, heads):
         super(CombinedROIHeads, self).__init__(heads)
         self.cfg = cfg.clone()
         if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
+            assert cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION == cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
             self.mask.feature_extractor = self.box.feature_extractor
         if cfg.MODEL.KEYPOINT_ON and cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
+            assert cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION == cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
             self.keypoint.feature_extractor = self.box.feature_extractor
 
     def forward(self, features, proposals, targets=None):
@@ -34,7 +36,7 @@ def forward(self, features, proposals, targets=None):
                 and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR
             ):
                 mask_features = x
-            # During training, self.box() will return the unaltered proposals as "detections"
+            # During training, self.mask() will return the unaltered proposals as "detections"
             # this makes the API consistent during training and testing
             x, detections, loss_mask = self.mask(mask_features, detections, targets)
             losses.update(loss_mask)
@@ -42,16 +44,17 @@ def forward(self, features, proposals, targets=None):
 
         if self.cfg.MODEL.KEYPOINT_ON:
             keypoint_features = features
             # optimization: during training, if we share the feature extractor between
-            # the box and the mask heads, then we can reuse the features already computed
+            # the box and the keypoint heads, then we can reuse the features already computed
             if (
                 self.training
                 and self.cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR
             ):
                 keypoint_features = x
-            # During training, self.box() will return the unaltered proposals as "detections"
+            # During training, self.keypoint() will return the unaltered proposals as "detections"
             # this makes the API consistent during training and testing
             x, detections, loss_keypoint = self.keypoint(keypoint_features, detections, targets)
             losses.update(loss_keypoint)
+
         return x, detections, losses
 
diff --git a/maskrcnn_benchmark/modeling/rpn/rpn.py b/maskrcnn_benchmark/modeling/rpn/rpn.py
index 07997651c..03af3f822 100644
--- a/maskrcnn_benchmark/modeling/rpn/rpn.py
+++ b/maskrcnn_benchmark/modeling/rpn/rpn.py
@@ -30,8 +30,8 @@ def __init__(self, cfg, in_channels, num_anchors):
         )
 
         for l in [self.cls_logits, self.bbox_pred]:
-            torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            nn.init.normal_(l.weight, std=0.01)
+            nn.init.constant_(l.bias, 0)
 
     def forward(self, x):
         assert isinstance(x, (list, tuple))
@@ -58,8 +58,8 @@ def __init__(self, cfg, in_channels):
         )
 
         for l in [self.conv]:
-            torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            nn.init.normal_(l.weight, std=0.01)
+            nn.init.constant_(l.bias, 0)
 
         self.out_channels = in_channels
 
@@ -93,8 +93,8 @@ def __init__(self, cfg, in_channels, num_anchors):
         )
 
         for l in [self.conv, self.cls_logits, self.bbox_pred]:
-            torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            nn.init.normal_(l.weight, std=0.01)
+            nn.init.constant_(l.bias, 0)
 
     def forward(self, x):
         logits = []
@@ -106,7 +106,7 @@ def forward(self, x):
         return logits, bbox_reg
 
 
-class RPNModule(torch.nn.Module):
+class RPNModule(nn.Module):
     """
     Module for RPN computation. Takes feature maps from the backbone and RPN
     proposals and losses. Works for both FPN and non-FPN.
@@ -117,25 +117,16 @@ def __init__(self, cfg, in_channels):
 
         self.cfg = cfg.clone()
 
-        anchor_generator = make_anchor_generator(cfg)
-
+        self.anchor_generator = make_anchor_generator(cfg)
         rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
-        head = rpn_head(
-            cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
+        self.head = rpn_head(
+            cfg, in_channels, self.anchor_generator.num_anchors_per_location()[0]
         )
 
         rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
-
-        box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True)
-        box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False)
-
-        loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder)
-
-        self.anchor_generator = anchor_generator
-        self.head = head
-        self.box_selector_train = box_selector_train
-        self.box_selector_test = box_selector_test
-        self.loss_evaluator = loss_evaluator
+        self.box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True)
+        self.box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False)
+        self.loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder)
 
     def forward(self, images, features, targets=None):
         """
diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py
index 25791d578..9aea6b925 100644
--- a/maskrcnn_benchmark/structures/bounding_box.py
+++ b/maskrcnn_benchmark/structures/bounding_box.py
@@ -101,7 +101,6 @@ def resize(self, size, *args, **kwargs):
             ratio = ratios[0]
             scaled_box = self.bbox * ratio
             bbox = BoxList(scaled_box, size, mode=self.mode)
-            # bbox._copy_extra_fields(self)
             for k, v in self.extra_fields.items():
                 if not isinstance(v, torch.Tensor):
                     v = v.resize(size, *args, **kwargs)
@@ -118,7 +117,6 @@ def resize(self, size, *args, **kwargs):
             (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
         )
         bbox = BoxList(scaled_box, size, mode="xyxy")
-        # bbox._copy_extra_fields(self)
         for k, v in self.extra_fields.items():
             if not isinstance(v, torch.Tensor):
                 v = v.resize(size, *args, **kwargs)
@@ -157,7 +155,6 @@ def transpose(self, method):
             (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
         )
         bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
-        # bbox._copy_extra_fields(self)
         for k, v in self.extra_fields.items():
             if not isinstance(v, torch.Tensor):
                 v = v.transpose(method)
@@ -185,7 +182,6 @@ def crop(self, box):
             (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
         )
         bbox = BoxList(cropped_box, (w, h), mode="xyxy")
-        # bbox._copy_extra_fields(self)
         for k, v in self.extra_fields.items():
             if not isinstance(v, torch.Tensor):
                 v = v.crop(box)
diff --git a/maskrcnn_benchmark/structures/image_list.py b/maskrcnn_benchmark/structures/image_list.py
index 590b87a65..6fda9ca2c 100644
--- a/maskrcnn_benchmark/structures/image_list.py
+++ b/maskrcnn_benchmark/structures/image_list.py
@@ -49,8 +49,8 @@ def to_image_list(tensors, size_divisible=0):
     elif isinstance(tensors, (tuple, list)):
         max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
 
-        # TODO Ideally, just remove this and let me model handle arbitrary
-        # input sizs
+        # TODO Ideally, just remove this and let the model handle arbitrary
+        # input sizes
         if size_divisible > 0:
             import math
 
diff --git a/maskrcnn_benchmark/structures/keypoint.py b/maskrcnn_benchmark/structures/keypoint.py
index a6881f72f..70fe1d987 100644
--- a/maskrcnn_benchmark/structures/keypoint.py
+++ b/maskrcnn_benchmark/structures/keypoint.py
@@ -11,9 +11,9 @@ def __init__(self, keypoints, size, mode=None):
         # in my version this would consistently return a CPU tensor
         device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device('cpu')
         keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
-        num_keypoints = keypoints.shape[0]
-        if num_keypoints:
-            keypoints = keypoints.view(num_keypoints, -1, 3)
+        num_persons = keypoints.shape[0]
+        if num_persons:
+            keypoints = keypoints.view(num_persons, -1, 3)
 
         # TODO should I split them?
         # self.visibility = keypoints[..., 2]
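
Usage sketch (not part of the patch): after the docstring change in demo/predictor.py,
run_on_opencv_image returns the input image with the detections drawn on it, rather
than a BoxList of predictions. A minimal driver illustrating the updated contract;
the config path, threshold values, and file names below are placeholders, not part
of this PR:

    # Hypothetical driver for demo/predictor.py after this change.
    import cv2
    from maskrcnn_benchmark.config import cfg
    from predictor import COCODemo

    # Placeholder config; any Mask R-CNN config from configs/ should work.
    cfg.merge_from_file("e2e_mask_rcnn_R_50_FPN_1x.yaml")

    coco_demo = COCODemo(cfg, confidence_threshold=0.7, min_image_size=800)

    image = cv2.imread("input.jpg")                # BGR np.ndarray, as returned by OpenCV
    result = coco_demo.run_on_opencv_image(image)  # np.ndarray with boxes/masks/keypoints drawn
    cv2.imwrite("output.jpg", result)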