Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Balandat authored Nov 7, 2022
2 parents 4e1eae3 + ae3f219 commit 402a45b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,7 @@ def mul(self, other: Union[float, torch.Tensor, "LinearOperator"]) -> LinearOper
other = torch.tensor(other, dtype=self.dtype, device=self.device)

try:
torch.broadcast_shapes(self.shape, other.shape)
broadcast_shape = torch.broadcast_shapes(self.shape, other.shape)
except RuntimeError:
raise RuntimeError(
"Cannot multiply LinearOperator of size {} by an object of size {}".format(self.shape, other.shape)
Expand All @@ -1761,7 +1761,7 @@ def mul(self, other: Union[float, torch.Tensor, "LinearOperator"]) -> LinearOper
if torch.is_tensor(other):
if other.numel() == 1:
return self._mul_constant(other.squeeze())
elif other.shape[-2:] == torch.Size((1, 1)):
elif other.shape[-2:] == torch.Size((1, 1)) and self.batch_shape == broadcast_shape[:-2]:
return self._mul_constant(other.view(*other.shape[:-2]))

return self._mul_matrix(to_linear_operator(other))
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/batch_repeat_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, base_linear_op, batch_repeat=torch.Size((1,))):
)
if isinstance(base_linear_op, BatchRepeatLinearOperator):
raise RuntimeError(
"BatchRepeatLinearOperator recieved the following args:\n"
"BatchRepeatLinearOperator received the following args:\n"
"base_linear_op: {} (size: {}), batch_repeat: {}.".format(
base_linear_op, base_linear_op.shape, batch_repeat
)
Expand Down
23 changes: 20 additions & 3 deletions test/operators/test_constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from linear_operator.operators import ToeplitzLinearOperator
from linear_operator.operators import DenseLinearOperator, ToeplitzLinearOperator
from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase
from linear_operator.utils.toeplitz import sym_toeplitz

Expand Down Expand Up @@ -42,7 +42,7 @@ def evaluate_linear_op(self, linear_op):

class TestConstantMulLinearOperatorMultiBatch(LinearOperatorTestCase, unittest.TestCase):
seed = 0
# Because these LTs are large, we'll skil the big tests
# Because these LTs are large, we'll skip the big tests
should_test_sample = False
skip_slq_tests = True

Expand All @@ -60,7 +60,7 @@ def evaluate_linear_op(self, linear_op):

class TestConstantMulLinearOperatorMultiBatchBroadcastConstant(LinearOperatorTestCase, unittest.TestCase):
seed = 0
# Because these LTs are large, we'll skil the big tests
# Because these LTs are large, we'll skip the big tests
should_test_sample = False
skip_slq_tests = True

Expand All @@ -76,5 +76,22 @@ def evaluate_linear_op(self, linear_op):
return toeplitz.to_dense() * constant


class TestConstantMulLinearOperatorBatchBroadcastOperator(LinearOperatorTestCase, unittest.TestCase):
"""Test which broadcasts the operator to match the constant tensor's batch size, see Github issue #33"""

seed = 0
should_test_sample = False
skip_slq_tests = True

def create_linear_op(self):
mat = torch.randn(5, 6)
mat = mat.matmul(mat.mT).reshape(1, 5, 5)
constant = torch.randn(2, 1, 1).abs()
return DenseLinearOperator(mat) * constant

def evaluate_linear_op(self, linear_op):
return linear_op.to_dense()


if __name__ == "__main__":
unittest.main()

0 comments on commit 402a45b

Please sign in to comment.