Unofficial implementation of Self-Critical Sequence Training (SCST) and various multi-head attention mechanisms.
This repo contains an experimental and unofficial implementation of image captioning frameworks, including:
- Self-Critical Sequence Training (SCST) [arxiv]
    - Sampling is done via beam search [arxiv]
- Multi-Head Visual Attention
- Graph-based Beam Search, Greedy Search and Sampling
The features might not be completely tested. For a more stable implementation, please refer to this repo.
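For context, SCST optimises the caption decoder with a policy gradient in which the reward of a sampled caption is baselined by the reward of the model's own test-time (greedy or beam-searched) caption. The snippet below is a minimal, framework-agnostic sketch of that loss; the reward values, function names, and numbers are illustrative assumptions, not code from this repo.

```python
import numpy as np

def scst_loss(log_probs, sampled_reward, baseline_reward):
    """Minimal SCST policy-gradient loss (sketch, not this repo's implementation).

    log_probs       : per-token log-probabilities of the sampled caption, shape [T]
    sampled_reward  : scalar reward of the sampled caption
    baseline_reward : scalar reward of the baseline (greedy / beam-search) caption
    """
    # Advantage: how much better the sample scored than the model's own test-time output.
    advantage = sampled_reward - baseline_reward
    # REINFORCE with a self-critical baseline: minimise the negative
    # advantage-weighted log-likelihood of the sampled caption.
    return -advantage * np.sum(log_probs)

# Illustrative numbers only: a 4-token sampled caption that beats the baseline.
log_probs = np.array([-0.5, -1.2, -0.3, -0.8])
print(scst_loss(log_probs, sampled_reward=0.9, baseline_reward=0.7))
```

In this repo, the reward during `scst` training is a weighted mix of CIDEr-D and BLEU, controlled by `scst_weight_ciderD` and `scst_weight_bleu` (see the training arguments below).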
Requirements:
- tensorflow 1.9.0
- python 2.7
- java 1.8.0
- tqdm >= 4.24.0
- Pillow >= 3.1.2
- requests >= 2.18.4
More examples are given in `example.sh`.
Run `./src/setup.sh`. This will download the required Stanford models and run all the dataset pre-processing scripts.
The training scheme is as follows:
- Start with `decoder` mode (freezing the CNN)
- Followed by `cnn_finetune` mode
- Finally, `scst` mode
```bash
# MS-COCO
for mode in 'decoder' 'cnn_finetune' 'scst'
do
    python train.py \
        --train_mode ${mode} \
        --token_type 'word' \
        --cnn_fm_projection 'tied' \
        --attn_num_heads 8
done

# InstaPIC
for mode in 'decoder' 'cnn_finetune' 'scst'
do
    python train.py \
        --train_mode ${mode} \
        --dataset_file_pattern 'insta_{}_v25595_s15' \
        --token_type 'word' \
        --cnn_fm_projection 'independent' \
        --attn_num_heads 8
done
```
Just point `infer.py` to the directory containing the checkpoints. Model configurations are loaded from `config.pkl`.
```bash
# MS-COCO
python infer.py \
    --infer_checkpoints_dir 'mscoco/word_add_softmax_h8_tie_lstm_run_01'

# InstaPIC
python infer.py \
    --infer_checkpoints_dir 'insta/word_add_softmax_h8_ind_lstm_run_01' \
    --dataset_dir '/path/to/insta/dataset' \
    --annotations_file 'insta_testval_raw.json'
```
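If you want to double-check which settings a checkpoint was trained with, you can inspect `config.pkl` directly. The following is just a convenience sketch assuming the file unpickles into a plain Python object or dict; the exact contents depend on the training run and are not documented here.

```python
import os
import pickle

# Hypothetical path, matching the MS-COCO example above.
checkpoint_dir = 'mscoco/word_add_softmax_h8_tie_lstm_run_01'
with open(os.path.join(checkpoint_dir, 'config.pkl'), 'rb') as f:
    config = pickle.load(f)

# Print whatever the pickled object exposes (dict keys or attributes).
print(vars(config) if hasattr(config, '__dict__') else config)
```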
Training arguments (`train.py`):
- Main:
    - `train_mode`: The training regime. Choices are `decoder`, `cnn_finetune`, `scst`.
    - `token_type`: Language model. Choices are `word`, `radix`, `char`.
- CNN:
    - `cnn_name`: CNN model name.
    - `cnn_input_size`: CNN input size.
    - `cnn_fm_attention`: End point name of the feature map used for attention.
    - `cnn_fm_projection`: Feature map projection method. Choices are `none`, `independent`, `tied`.
- RNN:
    - `rnn_name`: Type of RNN. Choices are `LSTM`, `LN_LSTM`, `GRU`.
    - `rnn_size`: Number of RNN units.
    - `rnn_word_size`: Size of the word embedding.
    - `rnn_init_method`: RNN initialisation method. Choices are `project_hidden`, `first_input`.
    - `rnn_recurr_dropout`: If `True`, enable variational recurrent dropout.
- Attention:
    - `attn_num_heads`: Number of attention heads.
    - `attn_context_layer`: If `True`, add a linear projection after multi-head attention.
    - `attn_alignment_method`: Alignment / composition method. Choices are `add_LN`, `add`, `dot`.
    - `attn_probability_fn`: Attention map probability function. Choices are `softmax`, `sigmoid`. (A sketch of these options is given after this list.)
- SCST:
    - `scst_beam_size`: The beam size for SCST sampling.
    - `scst_weight_ciderD`: The weight for the CIDEr-D metric during SCST training.
    - `scst_weight_bleu`: The weight for BLEU metrics during SCST training.
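To make the attention options concrete, below is a small NumPy sketch of single-query multi-head attention over flattened CNN feature-map locations, contrasting a dot-product (`dot`) with an additive (`add`) alignment and a `softmax` with a `sigmoid` probability function. Shapes, weight initialisation, and helper names are assumptions for illustration; this is not the repo's TensorFlow implementation, which additionally offers a layer-normalised additive alignment (`add_LN`) and tied feature-map projections.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def multi_head_attention(query, feature_map, num_heads=8,
                         alignment='dot', probability_fn='softmax', seed=0):
    """Single-query multi-head attention over L feature-map locations (sketch).

    query       : [d_model]      e.g. the RNN hidden state
    feature_map : [L, d_model]   e.g. flattened CNN feature map
    Returns the concatenated per-head context vector, shape [d_model].
    """
    rng = np.random.RandomState(seed)
    d_model = query.shape[0]
    d_head = d_model // num_heads
    contexts = []
    for _ in range(num_heads):
        # Per-head random projections for illustration; the repo's
        # 'independent' / 'tied' options control how such projections are shared.
        W_q = rng.randn(d_model, d_head) * 0.1
        W_k = rng.randn(d_model, d_head) * 0.1
        W_v = rng.randn(d_model, d_head) * 0.1
        q = query @ W_q                      # [d_head]
        k = feature_map @ W_k                # [L, d_head]
        v = feature_map @ W_v                # [L, d_head]
        if alignment == 'dot':
            scores = k @ q / np.sqrt(d_head)            # [L]
        else:  # 'add': additive (Bahdanau-style) alignment
            w = rng.randn(d_head) * 0.1
            scores = np.tanh(q[None, :] + k) @ w        # [L]
        probs = softmax(scores) if probability_fn == 'softmax' else sigmoid(scores)
        contexts.append(probs @ v)                      # [d_head]
    return np.concatenate(contexts)                     # [d_model]

# Hypothetical shapes: 8x8 feature map (64 locations), 512-d model.
ctx = multi_head_attention(np.random.randn(512), np.random.randn(64, 512),
                           num_heads=8, alignment='add', probability_fn='softmax')
print(ctx.shape)  # (512,)
```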
Inference arguments (`infer.py`):
- Main:
    - `infer_set`: The split to perform inference on. Choices are `test`, `valid`, `coco_test`, `coco_valid`. `coco_test` and `coco_valid` are for inference on the whole `test2014` and `val2014` sets respectively. These are used for MS-COCO online server evaluation.
    - `infer_checkpoints_dir`: Directory containing the checkpoint files.
    - `infer_checkpoints`: Checkpoint numbers to be evaluated. Comma-separated.
    - `annotations_file`: Annotations / reference file for calculating scores.
- Inference parameters:
    - `infer_beam_size`: Beam size of beam search. Pass `1` for greedy search. (A toy beam-search sketch is given after this list.)
    - `infer_length_penalty_weight`: Length penalty weight used in beam search.
    - `infer_max_length`: Maximum caption length allowed during inference.
    - `batch_size_infer`: Inference batch size for parallelism.
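As a rough illustration of how `infer_beam_size` and `infer_length_penalty_weight` interact, here is a toy beam search over a made-up step function. The GNMT-style length normalisation and all names here are assumptions; the repo's exact scoring may differ. With a beam size of 1 the same procedure reduces to greedy search.

```python
import numpy as np

def length_penalty(length, alpha):
    # GNMT-style normalisation (assumption; this repo's exact formula may differ).
    return ((5.0 + length) / 6.0) ** alpha

def beam_search(step_fn, start_id, end_id, beam_size=3, max_length=20, alpha=0.0):
    """Toy beam search. step_fn(prefix) -> log-prob vector over the vocabulary."""
    beams = [([start_id], 0.0)]           # (token prefix, summed log-prob)
    finished = []
    for _ in range(max_length):
        candidates = []
        for prefix, score in beams:
            log_probs = step_fn(prefix)
            # Expand each beam with its top `beam_size` next tokens.
            for tok in np.argsort(log_probs)[-beam_size:]:
                candidates.append((prefix + [int(tok)], score + float(log_probs[tok])))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_size]:
            (finished if prefix[-1] == end_id else beams).append((prefix, score))
        if not beams:
            break
    finished.extend(beams)
    # Rank hypotheses by their length-normalised score.
    return max(finished, key=lambda c: c[1] / length_penalty(len(c[0]), alpha))

# Hypothetical 5-token vocabulary; token 4 acts as the end-of-sentence marker.
rng = np.random.RandomState(0)
fake_step = lambda prefix: np.log(rng.dirichlet(np.ones(5)))
print(beam_search(fake_step, start_id=0, end_id=4, beam_size=3, max_length=10, alpha=1.0))
```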
Re-downloading can be avoided by:
- Editing `setup.sh`
- Providing the path to the directory containing the dataset files

```bash
python coco_prepro.py --dataset_dir /path/to/coco/dataset
python insta_prepro.py --dataset_dir /path/to/insta/dataset
```
In the same way, both `train.py` and `infer.py` accept alternative dataset paths.

```bash
python train.py --dataset_dir /path/to/dataset
python infer.py --dataset_dir /path/to/dataset
```
This code assumes the following dataset directory structures:
```
{coco-folder}
+-- captions
|   +-- {folder and files generated by coco_prepro.py}
+-- test2014
|   +-- {image files}
+-- train2014
|   +-- {image files}
+-- val2014
    +-- {image files}

{insta-folder}
+-- captions
|   +-- {folder and files generated by insta_prepro.py}
+-- images
|   +-- {image files}
+-- json
    +-- insta-caption-test1.json
    +-- insta-caption-train.json
```
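Before running the pre-processing or training scripts, you can quickly verify that your dataset folders match the expected layout. The following is a small stand-alone sketch (not part of this repo); the sub-folder names are taken from the trees above, and the root paths are placeholders.

```python
import os

def check_dirs(root, expected):
    # Report which of the expected sub-folders are present under `root`.
    for name in expected:
        path = os.path.join(root, name)
        print('{:<12} {}'.format(name, 'OK' if os.path.isdir(path) else 'MISSING'))

# Expected layouts, taken from the directory trees above.
check_dirs('/path/to/coco/dataset', ['captions', 'test2014', 'train2014', 'val2014'])
check_dirs('/path/to/insta/dataset', ['captions', 'images', 'json'])
```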
The repository is organised as follows:

```
.
+-- common
|   +-- {shared libraries and utility functions}
+-- datasets
|   +-- preprocessing
|       +-- {dataset pre-processing scripts}
+-- pretrained
|   +-- {pre-trained checkpoints for some COMIC models. Details are provided in a separate README.}
+-- src
    +-- {main scripts}
```
Thanks to the developers of:
- [attend2u]
- [coco-caption]
- [ruotianluo/self-critical.pytorch]
- [ruotianluo/cider]
- [weili-ict/SelfCriticalSequenceTraining-tensorflow]
- [tensorflow]
This project is open source under the Apache-2.0 license (see the `LICENSE` file).