Skip to content

Commit

Permalink
Merge pull request #37 from j-wilson/main
Browse files Browse the repository at this point in the history
Patches for `ConstantDiagLinearOperator._mul_constant` and `ConstantMulLinearOperator._getitem`.
  • Loading branch information
Balandat authored Nov 7, 2022
2 parents ae3f219 + 402a45b commit 7075542
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions linear_operator/operators/constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ def _getitem(self, row_index, col_index, *batch_indices):
# NOTE TO FUTURE SELF:
# This custom __getitem__ is actually very important!
# It prevents constructing an InterpolatedLinearOperator when one isn't needed
# This affects runntimes by up to 5x on simple exact GPs
# This affects runtimes by up to 5x on simple exact GPs
# Run __getitem__ on the base_linear_op and the constant
base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
constant = self._constant.expand(self.batch_shape)[batch_indices]
constant = constant.view(*constant.shape, 1, 1)
return base_linear_op * constant
return type(self)(base_linear_op=base_linear_op, constant=constant)

def _matmul(self, rhs):
res = self.base_linear_op._matmul(rhs)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _expand_batch(self, batch_shape: torch.Size) -> "ConstantDiagLinearOperator"
return self.__class__(self.diag_values.expand(*batch_shape, 1), diag_shape=self.diag_shape)

def _mul_constant(self, constant: Tensor) -> "ConstantDiagLinearOperator":
return self.__class__(self.diag_values * constant, diag_shape=self.diag_shape)
return self.__class__(self.diag_values * constant.unsqueeze(-1), diag_shape=self.diag_shape)

def _mul_matrix(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
if isinstance(other, ConstantDiagLinearOperator):
Expand Down

0 comments on commit 7075542

Please sign in to comment.