Model from MIT HAN Lab implemented in PyTorch.
Python 3.5+ is required. Install the dependencies with:
pip install -r requirements.txt
The PyTorch model definition is in mit_vww_pytorch.py.
To convert the model weights:
- Download the saved model model_fp32.pb from the authors' repository.
- Run the conversion script (by default, the checkpoint is saved to mitvww_pytorch.pth):
python load_weights_from_pb.py -m <path to the saved model>
The script reports the tensors that were not initialized. For this model, these should be only the num_batches_tracked tensors of the BatchNorm layers; all other tensors should be initialized from their TF counterparts.
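One way such a check can work, sketched under the assumption that the script compares the PyTorch model's state-dict keys with the keys it filled from TF tensors (in a real run the first list would come from model.state_dict().keys()). The key names and helper functions below are hypothetical illustrations, not the script's actual code.

```python
def find_uninitialized(model_keys, converted_keys):
    """Return the PyTorch tensor names that received no TF weights."""
    return sorted(set(model_keys) - set(converted_keys))


def only_batchnorm_counters_missing(missing):
    """True if every missing tensor is a BatchNorm num_batches_tracked counter."""
    return all(name.endswith(".num_batches_tracked") for name in missing)


# Hypothetical state-dict keys for a single conv + BatchNorm block.
model_keys = [
    "features.0.conv.weight",
    "features.0.bn.weight",
    "features.0.bn.bias",
    "features.0.bn.running_mean",
    "features.0.bn.running_var",
    "features.0.bn.num_batches_tracked",
]
# Keys filled from TF tensors: TF BatchNorm has no num_batches_tracked counter.
converted_keys = [k for k in model_keys if not k.endswith("num_batches_tracked")]

missing = find_uninitialized(model_keys, converted_keys)
print(missing)  # ['features.0.bn.num_batches_tracked']
print(only_batchnorm_counters_missing(missing))  # True
```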
The conversion script (TF saved model to PyTorch checkpoint) should also work for other models. To convert a different model:
- Implement the model in PyTorch with the same structure as the TF model.
- If needed, adjust the TF tensor names so that they match the tensor names in the PyTorch model.
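Renaming usually comes down to a few mechanical rules. The sketch below assumes TF-slim-style names such as conv1/weights and conv1/BatchNorm/moving_mean and maps them to PyTorch-style keys such as conv1.weight and conv1.bn.running_mean. Both naming schemes, the rename tables, and the function name are hypothetical; the real names depend on model_fp32.pb and on your PyTorch module layout.

```python
# Hypothetical rename tables: TF leaf/scope names -> PyTorch names.
LEAF_RENAMES = {
    "weights": "weight",
    "biases": "bias",
    "gamma": "weight",            # BatchNorm scale
    "beta": "bias",               # BatchNorm shift
    "moving_mean": "running_mean",
    "moving_variance": "running_var",
}
SCOPE_RENAMES = {"BatchNorm": "bn"}


def tf_to_torch_name(tf_name):
    """Map a TF tensor name to a PyTorch state-dict key (hypothetical scheme)."""
    # Drop a ":0" output-index suffix if present, e.g. "conv1/weights:0".
    tf_name = tf_name.split(":")[0]
    *scopes, leaf = tf_name.split("/")
    if leaf not in LEAF_RENAMES:
        raise KeyError("no rename rule for TF tensor {!r}".format(tf_name))
    # TF scopes are separated by '/', PyTorch module paths by '.'.
    scopes = [SCOPE_RENAMES.get(s, s) for s in scopes]
    return ".".join(scopes + [LEAF_RENAMES[leaf]])


print(tf_to_torch_name("conv1/weights"))                       # conv1.weight
print(tf_to_torch_name("conv1/BatchNorm/moving_variance:0"))   # conv1.bn.running_var
```

Renaming alone is not enough for convolution kernels: TF stores a standard conv2d kernel as (height, width, in_channels, out_channels), while a PyTorch Conv2d weight is (out_channels, in_channels, height, width), so the values also need transposing, for example with numpy.transpose(kernel, (3, 2, 0, 1)). Depthwise kernels need a different rearrangement.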
Optionally, run the tests to verify that the PyTorch and TF models produce the same outputs:
PYTHONPATH=. pytest
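Floating-point outputs from two frameworks rarely match bit for bit, so a test of this kind typically checks equality within a tolerance. A minimal sketch written as a pytest test function, assuming both outputs have been flattened to lists of floats; the function names, logit values, and tolerances are illustrative and are not taken from the repository's tests.

```python
import math


def outputs_match(tf_out, torch_out, rel_tol=1e-5, abs_tol=1e-5):
    """Return True if two flattened outputs agree element-wise within tolerance."""
    if len(tf_out) != len(torch_out):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(tf_out, torch_out)
    )


def test_outputs_match():
    # Hypothetical logits from the TF and PyTorch models for the same input.
    tf_logits = [2.3125, -1.5]
    torch_logits = [2.3125004, -1.4999998]
    assert outputs_match(tf_logits, torch_logits)
    # A real mismatch must still be detected.
    assert not outputs_match(tf_logits, [2.4, -1.5])
```

pytest collects functions whose names start with test_, so a function like this placed in a test file would run as part of PYTHONPATH=. pytest.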
You can save the TF graph and visualize it with TensorBoard:
python visualize.py -m <path to the saved model> -s <directory to save tensorboard logs>
Then view the graph with:
tensorboard --logdir <directory to save tensorboard logs>