Skip to content

Commit

Permalink
Merge pull request #72 from Turakar/fix_masked_t_matmul
Browse files Browse the repository at this point in the history
Fix _t_matmul() in MaskedLinearOperator and add test
  • Loading branch information
Balandat authored Jul 20, 2023
2 parents e915cc0 + 8c48538 commit 6ea8866
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions linear_operator/operators/masked_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _matmul(
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
rhs_expanded = self._expand(rhs, self.col_mask)
res_expanded = self.base.matmul(rhs_expanded)
res_expanded = self.base._matmul(rhs_expanded)
res = res_expanded[..., self.row_mask, :]

return res
Expand All @@ -60,7 +60,7 @@ def _t_matmul(
rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
rhs_expanded = self._expand(rhs, self.row_mask)
res_expanded = self.base.t_matmul(rhs_expanded)
res_expanded = self.base._t_matmul(rhs_expanded)
res = res_expanded[..., self.col_mask, :]
return res

Expand Down
14 changes: 14 additions & 0 deletions linear_operator/test/linear_operator_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,20 @@ def test_matmul_matrix(self):
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4)
return self._test_matmul(rhs)

def test_t_matmul_matrix(self):
with torch.no_grad():
linear_op = self.create_linear_op()
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
linear_op_copy = torch.clone(linear_op)
evaluated = self.evaluate_linear_op(linear_op_copy)
rhs_evaluated = to_dense(rhs)

# Test operator
res = linear_op._t_matmul(rhs)
actual = evaluated.mT.matmul(rhs_evaluated)
res_evaluated = to_dense(res)
self.assertAllClose(res_evaluated, actual)

def test_rmatmul_matrix(self):
linear_op = self.create_linear_op()
lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2))
Expand Down

0 comments on commit 6ea8866

Please sign in to comment.