Merge pull request #74 from CY-Zhang/main
Fix type error when converting interpolated_linear_operator to double precision
Balandat authored Aug 20, 2023
2 parents 4720378 + 9e458c2 commit d7b1988
Showing 6 changed files with 116 additions and 0 deletions.
27 changes: 27 additions & 0 deletions linear_operator/operators/identity_linear_operator.py
@@ -8,6 +8,8 @@
from jaxtyping import Float
from torch import Tensor

from ..utils.generic import _to_helper

from ..utils.getitem import _compute_getitem_size, _is_noop_index
from ..utils.memoize import cached
from ._linear_operator import IndexType, LinearOperator
@@ -255,3 +257,28 @@ def zero_mean_mvn_samples(
    ) -> Float[Tensor, "num_samples *batch N"]:
        base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device)
        return base_samples

    def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]:

        # Override the to() method in _linear_operator to also update the dtype and device stored in _kwargs.

        device, dtype = _to_helper(*args, **kwargs)

        new_args = []
        new_kwargs = {}
        for arg in self._args:
            if hasattr(arg, "to"):
                if hasattr(arg, "dtype") and arg.dtype.is_floating_point == dtype.is_floating_point:
                    new_args.append(arg.to(dtype=dtype, device=device))
                else:
                    new_args.append(arg.to(device=device))
            else:
                new_args.append(arg)
        for name, val in self._kwargs.items():
            if hasattr(val, "to"):
                new_kwargs[name] = val.to(dtype=dtype, device=device)
            else:
                new_kwargs[name] = val
        new_kwargs["device"] = device
        new_kwargs["dtype"] = dtype
        return self.__class__(*new_args, **new_kwargs)
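A standalone sketch (plain torch, not part of this commit) of how the is_floating_point parity check above decides which arguments receive a dtype cast:

import torch

target = torch.float64

values = torch.randn(3)                   # float32: parity matches -> dtype is cast
indices = torch.arange(3)                 # int64:   parity differs -> dtype is kept
mask = torch.tensor([True, False, True])  # bool:    parity differs -> dtype is kept

assert values.dtype.is_floating_point == target.is_floating_point
assert indices.dtype.is_floating_point != target.is_floating_point
assert mask.dtype.is_floating_point != target.is_floating_point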
26 changes: 26 additions & 0 deletions linear_operator/operators/interpolated_linear_operator.py
@@ -10,6 +10,7 @@

from ..utils import sparse
from ..utils.broadcasting import _pad_with_singletons
from ..utils.generic import _to_helper
from ..utils.getitem import _noop_index
from ..utils.interpolation import left_interp, left_t_interp
from ._linear_operator import IndexType, LinearOperator
@@ -454,3 +455,28 @@ def zero_mean_mvn_samples(
        res = left_interp(self.left_interp_indices, self.left_interp_values, base_samples).contiguous()
        batch_iter = tuple(range(res.dim() - 1))
        return res.permute(-1, *batch_iter).contiguous()

    def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]:

        # Override the to() method in _linear_operator to avoid converting index matrices to float.
        # The dtype is only converted when the argument and the target dtype are both floating point
        # (or both non-floating point); otherwise only the device is converted.

        device, dtype = _to_helper(*args, **kwargs)

        new_args = []
        new_kwargs = {}
        for arg in self._args:
            if hasattr(arg, "to"):
                if hasattr(arg, "dtype") and arg.dtype.is_floating_point == dtype.is_floating_point:
                    new_args.append(arg.to(dtype=dtype, device=device))
                else:
                    new_args.append(arg.to(device=device))
            else:
                new_args.append(arg)
        for name, val in self._kwargs.items():
            if hasattr(val, "to"):
                new_kwargs[name] = val.to(dtype=dtype, device=device)
            else:
                new_kwargs[name] = val
        return self.__class__(*new_args, **new_kwargs)
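For context, a rough usage sketch of the behavior this override targets (not part of the commit; the InterpolatedLinearOperator constructor argument order and the DenseLinearOperator import are assumptions inferred from the attributes referenced above):

import torch
from linear_operator.operators import DenseLinearOperator, InterpolatedLinearOperator

# Hypothetical setup: a small dense base plus random interpolation indices and weights.
base = DenseLinearOperator(torch.randn(4, 4))
left_idx = torch.randint(0, 4, (5, 2))    # integer index tensor
left_val = torch.rand(5, 2)               # floating-point interpolation weights
right_idx = torch.randint(0, 4, (5, 2))
right_val = torch.rand(5, 2)
op = InterpolatedLinearOperator(base, left_idx, left_val, right_idx, right_val)

op64 = op.to(torch.float64)
assert op64.dtype == torch.float64                            # floating-point args were cast
assert not op64.left_interp_indices.dtype.is_floating_point   # index tensors stay integer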
27 changes: 27 additions & 0 deletions linear_operator/operators/masked_linear_operator.py
@@ -4,6 +4,8 @@
from jaxtyping import Bool, Float
from torch import Tensor

from ..utils.generic import _to_helper

from ._linear_operator import _is_noop_index, IndexType, LinearOperator


@@ -111,3 +113,28 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice

    def _permute_batch(self, *dims: int) -> LinearOperator:
        return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask)

    def to(self: Float[LinearOperator, "*batch M N"], *args, **kwargs) -> Float[LinearOperator, "*batch M N"]:

        # Override the to() method in _linear_operator to avoid converting boolean mask matrices to float.
        # The dtype is only converted when the argument's floating-point-ness matches the target dtype's;
        # boolean masks therefore only have their device converted.

        device, dtype = _to_helper(*args, **kwargs)

        new_args = []
        new_kwargs = {}
        for arg in self._args:
            if hasattr(arg, "to"):
                if hasattr(arg, "dtype") and arg.dtype.is_floating_point == dtype.is_floating_point:
                    new_args.append(arg.to(dtype=dtype, device=device))
                else:
                    new_args.append(arg.to(device=device))
            else:
                new_args.append(arg)
        for name, val in self._kwargs.items():
            if hasattr(val, "to"):
                new_kwargs[name] = val.to(dtype=dtype, device=device)
            else:
                new_kwargs[name] = val
        return self.__class__(*new_args, **new_kwargs)
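Similarly, a rough usage sketch for the masked case (not part of the commit; the constructor arguments follow the self.__class__ call above, while the class name and the DenseLinearOperator import are assumptions):

import torch
from linear_operator.operators import DenseLinearOperator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator

# Hypothetical setup: a dense base restricted to a subset of rows and columns.
base = DenseLinearOperator(torch.randn(5, 5))
row_mask = torch.tensor([True, False, True, True, False])
col_mask = torch.tensor([True, True, False, True, False])
op = MaskedLinearOperator(base, row_mask, col_mask)

op64 = op.to(torch.float64)
assert op64.dtype == torch.float64        # base values were cast to double
assert op64.row_mask.dtype == torch.bool  # masks keep their boolean dtype
assert op64.col_mask.dtype == torch.bool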
10 changes: 10 additions & 0 deletions linear_operator/test/linear_operator_test_case.py
@@ -1247,3 +1247,13 @@ def test_svd(self):
        for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()):
            if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
                self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["svd"])

    def test_to_double(self):
        # Test that the linear_op is converted to torch.float64 and remains functional after calling to(torch.float64).
        linear_op = self.create_linear_op()
        try:
            linear_op = linear_op.to(torch.float64)
            linear_op.numpy()
        except RuntimeError:
            raise RuntimeError(f"Could not convert {type(linear_op)} to double.")
        self.assertEqual(linear_op.dtype, torch.float64)
13 changes: 13 additions & 0 deletions test/operators/test_interpolated_linear_operator.py
@@ -48,6 +48,19 @@ def evaluate_linear_op(self, linear_op):
        actual = left_matrix.matmul(base_tensor).matmul(right_matrix.t())
        return actual

    def test_to_double(self):
        # Override test_to_double from LinearOperatorTestCase.
        # Additionally check that the index matrices remain integer-typed after conversion.
        linear_op = self.create_linear_op()
        try:
            linear_op = linear_op.to(torch.float64)
            linear_op.numpy()
        except RuntimeError:
            raise RuntimeError(f"Could not convert {type(linear_op)} to double.")
        self.assertEqual(linear_op.dtype, torch.float64)
        self.assertFalse(linear_op.left_interp_indices.dtype.is_floating_point)
        self.assertFalse(linear_op.right_interp_indices.dtype.is_floating_point)


class TestInterpolatedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
    seed = 0
13 changes: 13 additions & 0 deletions test/operators/test_masked_linear_operator.py
@@ -25,6 +25,19 @@ def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]

    def test_to_double(self):
        # Override test_to_double from LinearOperatorTestCase.
        # Additionally check that the mask matrices remain boolean after conversion.
        linear_op = self.create_linear_op()
        try:
            linear_op = linear_op.to(torch.float64)
            linear_op.numpy()
        except RuntimeError:
            raise RuntimeError(f"Could not convert {type(linear_op)} to double.")
        self.assertEqual(linear_op.dtype, torch.float64)
        self.assertEqual(linear_op.col_mask.dtype, torch.bool)
        self.assertEqual(linear_op.row_mask.dtype, torch.bool)


class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
    seed = 2023
