Remove non-parallel approaches to get a device #8827
Labels: enhancement (New feature or request), usability (Bugs/features related to improving the usability of PyTorch/XLA)
🐛 Bug
There are four ways to get a device object. Three of them could be deprecated: they are likely unnecessary, and they are not parallel to torch.cuda.device(), which does not return a device but is in fact a context manager.
1. torch.device('xla', 3)
2. torch_xla.device(3)
3. torch_xla.torch_xla.device(3)
4. torch_xla.core.xla_model.xla_device(3)
To Reproduce
Run the following in a Colab notebook with a TPU enabled and torch_xla pip-installed at the matching version.
torch.device('xla', 3)
torch_xla.device(3)
torch_xla.torch_xla.device(3)
torch_xla.core.xla_model.xla_device(3)
Expected behavior
The three ways to get a device that live inside torch_xla should be deprecated.
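One way to stage the deprecation is a thin shim that forwards to the surviving API while emitting a DeprecationWarning. This is only a sketch of the pattern: the `device` function below is a hypothetical stand-in, not the real torch_xla implementation.

```python
import functools
import warnings

def deprecated(replacement):
    """Decorator that warns callers to migrate to `replacement`."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated; use {replacement} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)
        return inner
    return wrap

# Hypothetical stand-in for torch_xla.device; the real shim would
# delegate to the same underlying code as torch.device('xla', n).
@deprecated("torch.device('xla', <optional int>)")
def device(index=0):
    return ("xla", index)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = device(3)

print(result)                       # ('xla', 3)
print(caught[0].category.__name__)  # DeprecationWarning
```

Using `stacklevel=2` makes the warning point at the caller's line rather than the shim, which is what users need to locate the code to migrate.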
Docs that use them should be updated to
torch.device("xla", <optional int>)
and the import of torch_xla.core.xla_model removed where it is no longer needed.

Environment
TPU on Colab
Additional context
Should verify that the four APIs do ultimately call the same underlying function to guarantee equivalent behavior.