Skip to content

Commit

Permalink
Merge branch 'main' into issue_template_filenames
Browse files Browse the repository at this point in the history
  • Loading branch information
Balandat authored Nov 7, 2022
2 parents c8939ee + 7075542 commit 91b931a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions linear_operator/operators/constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ def _getitem(self, row_index, col_index, *batch_indices):
# NOTE TO FUTURE SELF:
# This custom __getitem__ is actually very important!
# It prevents constructing an InterpolatedLinearOperator when one isn't needed
# This affects runntimes by up to 5x on simple exact GPs
# This affects runtimes by up to 5x on simple exact GPs
# Run __getitem__ on the base_linear_op and the constant
base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
constant = self._constant.expand(self.batch_shape)[batch_indices]
constant = constant.view(*constant.shape, 1, 1)
return base_linear_op * constant
return type(self)(base_linear_op=base_linear_op, constant=constant)

def _matmul(self, rhs):
res = self.base_linear_op._matmul(rhs)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _expand_batch(self, batch_shape: torch.Size) -> "ConstantDiagLinearOperator"
return self.__class__(self.diag_values.expand(*batch_shape, 1), diag_shape=self.diag_shape)

def _mul_constant(self, constant: Tensor) -> "ConstantDiagLinearOperator":
return self.__class__(self.diag_values * constant, diag_shape=self.diag_shape)
return self.__class__(self.diag_values * constant.unsqueeze(-1), diag_shape=self.diag_shape)

def _mul_matrix(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
if isinstance(other, ConstantDiagLinearOperator):
Expand Down

0 comments on commit 91b931a

Please sign in to comment.