Official PyTorch implementation of the paper: "End-to-End Supermask Pruning: Learning to Prune Image Captioning Models"
Released on July 20, 2021
This work explores model pruning for the image captioning task for the first time. Empirically, we show that networks at 80% to 95% sparsity can match or even slightly outperform their dense counterparts. To promote Green Computer Vision, we release pre-trained sparse models for UD and ORT that achieve CIDEr scores >120 on the MS-COCO dataset, yet are only 8.7 MB (a 96% reduction compared to dense UD) and 14.5 MB (a 94% reduction compared to dense ORT) in size.
Figure 1: Deep captioning networks at 80% to 95% sparsity can match or even slightly outperform their dense counterparts.
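The core idea of Supermask Pruning is to learn a binary weight mask jointly with the network weights, end-to-end. Below is a minimal sketch of that mechanism, assuming deterministic thresholding of learnable mask scores with a straight-through estimator; the class and names are illustrative, not this repo's API, and the paper's actual SMP procedure differs in detail.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinariseSTE(torch.autograd.Function):
    """Threshold mask scores to {0, 1} in the forward pass;
    pass gradients straight through in the backward pass."""

    @staticmethod
    def forward(ctx, scores):
        return (torch.sigmoid(scores) >= 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # straight-through estimator


class SupermaskLinear(nn.Linear):
    """Linear layer whose weights are gated by a learned binary mask."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # One learnable score per weight, trained end-to-end alongside the
        # weights; a regulariser on the scores can push towards a target sparsity.
        self.mask_scores = nn.Parameter(torch.zeros_like(self.weight))

    def forward(self, x):
        mask = BinariseSTE.apply(self.mask_scores)
        return F.linear(x, self.weight * mask, self.bias)
```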
Please refer to the documentation.
- Captioning models built using PyTorch
  - Up-Down LSTM
  - Object Relation Transformer
  - A Compact Object Relation Transformer (ACORT)
- Unstructured weight pruning (a magnitude-pruning sketch follows this list)
  - Supermask Pruning (SMP, end-to-end pruning)
  - Gradual magnitude pruning
  - Lottery ticket
  - One-shot magnitude pruning (paper 1, paper 2)
  - Single-shot Network Pruning (SNIP)
- Self-Critical Sequence Training (SCST) (a reward sketch follows this list)
  - Random sampling + Greedy search baseline: vanilla SCST
  - Beam search sampling + Greedy search baseline: à la Up-Down
  - Random sampling + Sample mean baseline: arxiv paper
  - Beam search sampling + Sample mean baseline: à la M2 Transformer
  - Optimise CIDEr and/or BLEU scores with custom weightage
  - Based on ruotianluo/self-critical.pytorch
- Multiple captions per image during teacher-forcing training
  - Reduces training time: the encoder runs once, then optimisation proceeds over multiple training captions
- Incremental decoding (Transformer with attention cache; a key/value cache sketch follows this list)
- Data caching during training
  - Training examples are cached in memory to reduce disk I/O
  - With sufficient memory, the entire training set can be served from memory after the first epoch
  - Memory usage can be controlled via the `cache_min_free_ram` flag (a caching sketch follows this list)
- `coco_caption` in Python 3
  - Based on salaniz/pycocoevalcap
- Tokenizers based on `sentencepiece`
  - Word
  - Radix encoding
  - (untested) Unigram, BPE, Character
- Datasets
  - MS-COCO
  - (contributions welcome) Flickr8k, Flickr30k, InstaPIC-1.1M
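For contrast with SMP, one-shot magnitude pruning simply zeroes the smallest-magnitude weights of a trained model. A minimal sketch of the "class-blind" variant, which uses a single global threshold across layers (the function name is illustrative, not this repo's API):

```python
import torch


def magnitude_prune(model, sparsity):
    """Zero out the `sparsity` fraction of smallest-magnitude weights, using
    one global threshold across all weight matrices (class-blind). The
    class-uniform variant instead computes one threshold per layer."""
    weights = [p for p in model.parameters() if p.dim() > 1]
    magnitudes = torch.cat([p.abs().flatten() for p in weights])
    k = int(sparsity * magnitudes.numel())
    if k == 0:
        return
    threshold = torch.kthvalue(magnitudes, k).values
    with torch.no_grad():
        for p in weights:
            p.mul_((p.abs() > threshold).float())
```

Gradual magnitude pruning applies the same idea repeatedly during training, ramping `sparsity` up to the target level instead of pruning once.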
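The SCST variants listed above differ only in how captions are sampled and how the reward baseline is computed. A minimal sketch of the policy-gradient loss with custom metric weightage; function names and the example weights are illustrative, not this repo's API:

```python
import torch


def blended_reward(cider, bleu4, cider_weight=1.0, bleu_weight=0.0):
    """Combine per-caption metric scores into one reward with custom weightage."""
    return cider_weight * cider + bleu_weight * bleu4


def scst_loss(sample_logprobs, sample_reward, baseline_reward):
    """REINFORCE with a baseline. In vanilla SCST the baseline is the reward
    of the greedy-decoded caption; the sample-mean variants use the mean
    reward over multiple sampled captions instead.

    sample_logprobs: (batch, seq_len) log-probabilities of the sampled caption.
    """
    advantage = (sample_reward - baseline_reward).detach()
    return -(advantage * sample_logprobs.sum(dim=-1)).mean()
```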
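Incremental decoding avoids recomputing attention over already-generated tokens by caching each layer's keys and values. A single-head sketch of the idea (illustrative only, not this repo's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedSelfAttention(nn.Module):
    """Self-attention that projects only the newest token at each decoding
    step and appends its key/value to a running cache."""

    def __init__(self, d_model):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.scale = d_model ** -0.5

    def forward(self, x_new, cache=None):
        # x_new: (batch, 1, d_model), the latest decoded token only
        q, k, v = self.qkv(x_new).chunk(3, dim=-1)
        if cache is not None:  # reuse keys/values from earlier steps
            k = torch.cat([cache["k"], k], dim=1)
            v = torch.cat([cache["v"], v], dim=1)
        new_cache = {"k": k, "v": v}
        attn = F.softmax((q @ k.transpose(1, 2)) * self.scale, dim=-1)
        return self.out(attn @ v), new_cache
```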
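The training-data cache grows only while free RAM stays above a threshold, which is the behaviour the `cache_min_free_ram` flag controls. A sketch of such a policy using `psutil`; the class and its names are illustrative, not this repo's API:

```python
import psutil


class FreeRamBoundedCache:
    """In-memory cache that stops admitting new items once available RAM
    drops below `min_free_ram` (a fraction of total system memory)."""

    def __init__(self, min_free_ram=0.2):
        self.min_free_ram = min_free_ram
        self._store = {}

    def get(self, key, load_fn):
        if key in self._store:
            return self._store[key]  # served from memory, no disk I/O
        value = load_fn(key)         # fall back to loading from disk
        vm = psutil.virtual_memory()
        if vm.available / vm.total > self.min_free_ram:
            self._store[key] = value  # cache only while headroom remains
        return value
```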
The checkpoints are available at this repo.
Soft-attention models implemented in TensorFlow 1.9 are available at this repo.
Up-Down (UD), CIDEr scores on MS-COCO:

Sparsity | NNZ (non-zero params) | Dense Baseline | SMP | Lottery ticket (class-blind) | Lottery ticket (class-uniform) | Lottery ticket (gradual) | Gradual pruning | Hard pruning (class-blind) | Hard pruning (class-distribution) | Hard pruning (class-uniform) | SNIP
---|---|---|---|---|---|---|---|---|---|---|---
0.950 | 2.7 M | 111.3 | 112.5 | - | 107.7 | 109.5 | 109.7 | - | 110.0 | 110.2 | 38.2
0.975 | 1.3 M | 111.3 | 110.6 | - | 103.8 | 106.6 | 107.0 | - | 105.9 | 105.4 | 34.7
0.988 | 0.7 M | 111.3 | 109.0 | - | 99.3 | 102.2 | 103.4 | - | 101.3 | 100.5 | 32.6
0.991 | 0.5 M | 111.3 | 107.8 | - | - | - | - | - | - | - | -
Object Relation Transformer (ORT), CIDEr scores on MS-COCO:

Sparsity | NNZ (non-zero params) | Dense Baseline | SMP | Lottery ticket (gradual) | Gradual pruning | Hard pruning (class-blind) | Hard pruning (class-distribution) | Hard pruning (class-uniform) | SNIP
---|---|---|---|---|---|---|---|---|---
0.950 | 2.8 M | 114.7 | 113.7 | 115.7 | 115.3 | 4.1 | 112.5 | 113.0 | 47.2
0.975 | 1.4 M | 114.7 | 113.7 | 112.9 | 113.2 | 0.7 | 106.6 | 106.9 | 44.0
0.988 | 0.7 M | 114.7 | 110.7 | 109.8 | 110.0 | 0.9 | 96.9 | 59.8 | 37.3
0.991 | 0.5 M | 114.7 | 109.3 | 107.1 | 107.0 | - | - | - | -
- SCST, Up-Down: ruotianluo/self-critical.pytorch
- Object Relation Transformer: yahoo/object_relation_transformer
- `coco_caption` in Python 3: salaniz/pycocoevalcap
If you find this work useful for your research, please cite:
@article{tan2021end,
title={End-to-End Supermask Pruning: Learning to Prune Image Captioning Models},
author={Tan, Jia Huei and Chan, Chee Seng and Chuah, Joon Huang},
journal={Pattern Recognition},
pages={108366},
year={2021},
publisher={Elsevier},
doi={10.1016/j.patcog.2021.108366}
}
Suggestions and opinions on this work (both positive and negative) are greatly welcomed. Please contact the authors by sending an email to tan.jia.huei at gmail.com or cs.chan at um.edu.my.
The project is open source under the BSD-3 license (see the LICENSE file).
©2021 Universiti Malaya.
Run Black linting:

    black --line-length=120 --safe sparse_caption
    black --line-length=120 --safe tests
    black --line-length=120 --safe scripts