Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edits to generic add, BlockDiagLinearOperator's matmul, and documentation #10

Merged
merged 2 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

[![Run Test Suite](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml/badge.svg)](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml)
[![Documentation Status](https://readthedocs.org/projects/linear-operator/badge/?version=latest)](https://linear-operator.readthedocs.io/en/latest/?badge=latest)

## Development
To run unit tests:
```
python -m unittest discover
```
2 changes: 1 addition & 1 deletion linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,7 +2557,7 @@ def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOp
from .zero_linear_operator import ZeroLinearOperator

if isinstance(other, ZeroLinearOperator):
return self
return deepcopy(self)
elif isinstance(other, DiagLinearOperator):
return AddedDiagLinearOperator(self, other)
elif isinstance(other, RootLinearOperator):
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/operators/block_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
def matmul(self, other):
from .diag_linear_operator import DiagLinearOperator

# this is trivial if we multiply two BlockDiagLinearOperator
if isinstance(other, BlockDiagLinearOperator):
# this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes
if isinstance(other, BlockDiagLinearOperator) and self.base_linear_op.shape == other.base_linear_op.shape:
return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
# special case if we have a DiagLinearOperator
if isinstance(other, DiagLinearOperator):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/cat_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def cat(inputs, dim=0, output_device=None):

class CatLinearOperator(LinearOperator):
r"""
A `LinearOperator` that represents the concatenation of other lazy tensors.
A `LinearOperator` that represents the concatenation of other linear operators.
Each LinearOperator must have the same shape except in the concatenating
dimension.
Expand Down
10 changes: 5 additions & 5 deletions linear_operator/operators/constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ConstantMulLinearOperator(LinearOperator):

.. note::

To element-wise multiply two lazy tensors, see :class:`linear_operator.lazy.MulLinearOperator`
To element-wise multiply two lazy tensors, see :class:`linear_operator.operators.MulLinearOperator`

Args:
base_linear_op (LinearOperator) or (b x n x m)): The base_lazy tensor
Expand All @@ -38,18 +38,18 @@ class ConstantMulLinearOperator(LinearOperator):

Example::

>>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([1, 2, 3])
>>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([1, 2, 3])
>>> constant = torch.tensor(1.2)
>>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op.to_dense()
>>> # Returns:
>>> # [[ 1.2, 2.4, 3.6 ]
>>> # [ 2.4, 1.2, 2.4 ]
>>> # [ 3.6, 2.4, 1.2 ]]
>>>
>>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
>>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
>>> constant = torch.tensor([1.2, 0.5])
>>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op.to_dense()
>>> # Returns:
>>> # [[[ 1.2, 2.4, 3.6 ]
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/test/linear_operator_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,13 +883,13 @@ def _test_triangular_linear_op_inv_quad_logdet(self):
linear_op = self.create_linear_op()
rootdecomp = linear_operator.root_decomposition(linear_op)

if isinstance(rootdecomp, linear_operator.lazy.CholLinearOperator):
if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
chol = linear_operator.root_decomposition(linear_op).root.clone()
linear_operator.utils.memoize.clear_cache_hook(linear_op)
linear_operator.utils.memoize.add_to_cache(
linear_op,
"root_decomposition",
linear_operator.lazy.RootLinearOperator(chol),
linear_operator.operators.RootLinearOperator(chol),
)

_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/utils/contour_integral_quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def contour_integral_quad(
Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`
using contour integral quadrature.
:param linear_operator.lazy.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
:param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
:param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b`
:param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False)
or `\mathbf K^{-1/2} \mathbf b` (if True)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/utils/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def apply_permutation(
Broadcasting rules apply.
:param matrix: :math:`\mathbf K`
:type matrix: ~linear_operator.lazy.LinearOperator or ~torch.Tensor (... x n x n)
:type matrix: ~linear_operator.operators.LinearOperator or ~torch.Tensor (... x n x n)
:param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}`
:type left_permutation: ~torch.Tensor, optional (... x <= n)
:param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`
Expand Down