
Vision Transformer PyTorch

This project is modified from lukemelas/EfficientNet-PyTorch and asyml/vision-transformer-pytorch to provide an out-of-the-box API, letting you use Vision Transformer as easily as EfficientNet.

Quickstart

Install with pip install vision_transformer_pytorch and load a pretrained VisionTransformer with:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_pretrained('ViT-B_16')

About Vision Transformer PyTorch

Vision Transformer PyTorch is a PyTorch re-implementation of Vision Transformer. It builds on EfficientNet-PyTorch, a best-practice example among widely used deep learning libraries, and on vision-transformer-pytorch, an elegant implementation of Vision Transformer. In this project, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

If you have any feature requests or questions, feel free to leave them as GitHub issues!

Installation

Install via pip:

pip install vision_transformer_pytorch

Or install from source:

git clone https://github.com/tczhangzhi/VisionTransformer-Pytorch
cd VisionTransformer-Pytorch
pip install -e .

Usage

Loading pretrained models

Load a Vision Transformer by name (with randomly initialized weights):

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16') # or 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16'

Load a pretrained Vision Transformer:

import torch
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')  # or 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16'
inputs = torch.randn(1, 3, *model.image_size)  # dummy batch at the model's expected resolution
logits = model(inputs)                         # classification logits
features = model.extract_features(inputs)      # features before the classification head
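
For end-to-end inference on a real image, you would typically resize and normalize the input first. Below is a minimal sketch, assuming torchvision is available and using standard ImageNet mean/std normalization; the exact preprocessing the pretrained weights expect may differ, so check examples/imagenet for the authoritative pipeline. The image path is a placeholder.

import torch
from PIL import Image
from torchvision import transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Assumed preprocessing: resize to the model's input size and apply the
# standard ImageNet statistics (an assumption, not a guarantee that it
# matches the original training pipeline exactly).
preprocess = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open('cat.jpg').convert('RGB')  # 'cat.jpg' is a placeholder path
batch = preprocess(img).unsqueeze(0)        # add a batch dimension

with torch.no_grad():
    logits = model(batch)
probs = torch.softmax(logits, dim=-1)
print(probs.argmax(dim=-1))                 # predicted ImageNet class index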

Default hyperparameters:

| Param \ Model     | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 | R50+ViT-B_16 |
| ----------------- | -------- | -------- | -------- | -------- | ------------ |
| image_size        | 384      | 384      | 384      | 384      | 384          |
| patch_size        | 16       | 32       | 16       | 32       | 1            |
| emb_dim           | 768      | 768      | 1024     | 1024     | 768          |
| mlp_dim           | 3072     | 3072     | 4096     | 4096     | 3072         |
| num_heads         | 12       | 12       | 16       | 16       | 12           |
| num_layers        | 12       | 12       | 24       | 24       | 12           |
| num_classes       | 1000     | 1000     | 1000     | 1000     | 1000         |
| attn_dropout_rate | 0.0      | 0.0      | 0.0      | 0.0      | 0.0          |
| dropout_rate      | 0.1      | 0.1      | 0.1      | 0.1      | 0.1          |
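
As a quick sanity check on the table, image_size and patch_size together fix the transformer's sequence length. For ViT-B_16 this is plain arithmetic, not a library call:

num_patches = (384 // 16) ** 2  # a 384x384 image cut into 16x16 patches: 24 * 24 = 576
num_tokens = num_patches + 1    # plus the class token: 577 tokens per image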

If you need to modify these hyperparameters, use:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16', image_size=256, patch_size=64, ...)

ImageNet

See examples/imagenet for details about evaluating on ImageNet.
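
The details live in that folder; for orientation only, a top-1 evaluation loop generally looks like the sketch below. It assumes a torchvision ImageFolder layout for the validation set ('path/to/imagenet/val' is a placeholder) and the same assumed preprocessing as the inference sketch above; it is not the repository's actual script.

import torch
from torchvision import datasets, transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Placeholder preprocessing and data path; see examples/imagenet for the
# repository's actual settings.
preprocess = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
val_set = datasets.ImageFolder('path/to/imagenet/val', transform=preprocess)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, num_workers=4)

correct = total = 0
with torch.no_grad():
    for images, labels in val_loader:
        preds = model(images).argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'top-1 accuracy: {correct / total:.4f}')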

Contributing

If you find a bug, create a GitHub issue or, even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!