🚀 Feature
Instead of using PyTorch autograd and checkpointing, we'll investigate using `jax.grad`, `jax.vjp`, `jax.remat`, etc. to control the rematerialization of a PyTorch model.
Motivation
JAX remat is more powerful than PyTorch autograd checkpointing. For example, we can name individual tensors and selectively save or offload them for the backward pass. PyTorch has no equivalent way to name a tensor for this purpose.
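A minimal sketch of tensor naming on the JAX side, using `jax.ad_checkpoint.checkpoint_name` and the `jax.checkpoint_policies.save_only_these_names` policy. The function `layer`, its shapes, and the name `"hidden"` are hypothetical, and this is plain JAX rather than a converted PyTorch model. JAX also provides an offloading policy (`save_and_offload_only_these_names`), not shown here because it depends on backend memory-kind support.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(w1, w2, x):
    # Tag this intermediate with a name that a remat policy can refer to.
    # Outside jax.checkpoint, checkpoint_name is just an identity.
    h = checkpoint_name(jnp.sin(x @ w1), "hidden")
    return jnp.sum(jnp.tanh(h @ w2))

# Keep only tensors named "hidden" as residuals; other intermediates
# are recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("hidden")
layer_remat = jax.checkpoint(layer, policy=policy)

w1 = jnp.full((4, 8), 0.1)
w2 = jnp.full((8, 2), 0.1)
x = jnp.arange(12.0).reshape(3, 4) / 10

# Gradients w.r.t. both weight matrices, with and without the policy.
grads_remat = jax.grad(layer_remat, argnums=(0, 1))(w1, w2, x)
grads_plain = jax.grad(layer, argnums=(0, 1))(w1, w2, x)
```

The gradients are the same either way; the policy only changes which tensors are stored versus recomputed.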
Pitch
Something like https://github.com/tengyifei/playground/blob/master/torch-jax-autograd.ipynb combined with #8781 and torchax.
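On the JAX side, the pitch amounts to applying JAX's differentiation transforms to a model's forward function. A minimal sketch, assuming a plain JAX function `mlp` as a hypothetical stand-in for a PyTorch model (the torchax conversion step from the proposal is not shown): `jax.vjp` splits the forward and backward passes, and `jax.checkpoint` (also exported as `jax.remat`) controls what the forward pass keeps.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Hypothetical stand-in for a model's forward pass.
    w1, w2 = params
    h = jnp.tanh(x @ w1)
    return jnp.mean((h @ w2) ** 2)

# With the default policy, jax.checkpoint keeps essentially only its inputs,
# so the intermediates (e.g. h) are recomputed in the backward pass.
mlp_remat = jax.checkpoint(mlp)

params = (jnp.full((4, 8), 0.1), jnp.full((8, 2), 0.2))
x = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)

# jax.vjp runs the forward pass and returns a backward-pass function.
loss, vjp_fn = jax.vjp(mlp_remat, params, x)
# Seed the backward pass with a cotangent of 1 for the scalar loss.
param_grads, x_grad = vjp_fn(jnp.ones_like(loss))

# Reference: parameter gradients via jax.grad on the un-rematerialized function.
ref_grads = jax.grad(mlp)(params, x)
```

Splitting forward and backward this way leaves the caller in control of when the backward pass runs, while the remat policy decides what it has to recompute.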
cc @qihqi