Skip to content

Commit

Permalink
Add empty_like for NF4Tensor to support offloading (#881)
Browse files Browse the repository at this point in the history
Driss, we can confirm semantics when you're back but I'm fairly confident this is okay
  • Loading branch information
janeyx99 authored Sep 16, 2024
1 parent 8aa6533 commit a584e24
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
10 changes: 10 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
run_tests,
)
from torchao.dtypes.nf4tensor import (
NF4Tensor,
linear_nf4,
to_nf4,
_INNER_TENSOR_NAMES_FOR_SHARDING,
Expand Down Expand Up @@ -270,6 +271,15 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):

torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_empty_like(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.rand(input_size, device="cuda"))
new_tensor = torch.empty_like(nf4_tensor, device="cpu")
self.assertTrue(isinstance(new_tensor, NF4Tensor))
self.assertEqual(new_tensor.get_device(), -1) # that it's on CPU
self.assertEqual(new_tensor.size(), nf4_tensor.size())


class TestFSDPOps(TestCase):
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
Expand Down
14 changes: 14 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,20 @@ def nf4_detach(aten_op, args, kwargs=None):
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@implements(
[
aten.empty_like.default,
]
)
def nf4_empty_like(aten_op, args, kwargs=None):
nf4tensor = args[0]
updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
if kwargs is not None and len(kwargs):
for key, value in kwargs.items():
updated_attrs[key] = value
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@implements(
[
aten.split.Tensor,
Expand Down

0 comments on commit a584e24

Please sign in to comment.