Merge pull request #39 from SebastianAment/permutation

Adding Permutation Linear Operators

Balandat authored Nov 16, 2022
2 parents 785014a + ff7868b commit 09b891d
Showing 4 changed files with 268 additions and 3 deletions.
linear_operator/operators/__init__.py (5 changes: 4 additions & 1 deletion)
@@ -24,6 +24,7 @@
from .low_rank_root_linear_operator import LowRankRootLinearOperator
from .matmul_linear_operator import MatmulLinearOperator
from .mul_linear_operator import MulLinearOperator
+from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
from .psd_sum_linear_operator import PsdSumLinearOperator
from .root_linear_operator import RootLinearOperator
from .sum_batch_linear_operator import SumBatchLinearOperator
@@ -47,6 +48,7 @@
"CholLinearOperator",
"ConstantDiagLinearOperator",
"ConstantMulLinearOperator",
"DenseLinearOperator",
"DiagLinearOperator",
"IdentityLinearOperator",
"InterpolatedLinearOperator",
@@ -60,12 +62,13 @@
"LowRankRootLinearOperator",
"MatmulLinearOperator",
"MulLinearOperator",
"DenseLinearOperator",
"PermutationLinearOperator",
"PsdSumLinearOperator",
"RootLinearOperator",
"SumLinearOperator",
"SumBatchLinearOperator",
"ToeplitzLinearOperator",
"TransposePermutationLinearOperator",
"TriangularLinearOperator",
"ZeroLinearOperator",
]
linear_operator/operators/_linear_operator.py (4 changes: 2 additions & 2 deletions)
@@ -2772,7 +2772,7 @@ def __torch_function__(
name = func.__name__.replace("linalg_", "linalg.")
arg_classes = ", ".join(arg.__class__.__name__ for arg in args)
kwarg_classes = ", ".join(f"{key}={val.__class__.__name__}" for key, val in kwargs.items())
-            raise NotImplementedError(f"torch.{name}({arg_classes}{kwarg_classes}) is not implemented.")
+            raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
# Hack: get the appropriate class function based on its name
# As a result, we will call the subclass method (when applicable) rather than the superclass method
func = getattr(cls, _HANDLED_SECOND_ARG_FUNCTIONS[func])
@@ -2782,7 +2782,7 @@
name = func.__name__.replace("linalg_", "linalg.")
arg_classes = ", ".join(arg.__class__.__name__ for arg in args)
kwarg_classes = ", ".join(f"{key}={val.__class__.__name__}" for key, val in kwargs.items())
-            raise NotImplementedError(f"torch.{name}({arg_classes}{kwarg_classes}) is not implemented.")
+            raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
# Hack: get the appropriate class function based on its name
# As a result, we will call the subclass method (when applicable) rather than the superclass method
func = getattr(cls, _HANDLED_FUNCTIONS[func])
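(For context: before this fix, the positional and keyword argument class lists were concatenated with no separator, so a failing call could report, e.g., "torch.cholesky_solve(LinearOperator, Tensorupper=bool) is not implemented."; with the added ", ", the same message reads "torch.cholesky_solve(LinearOperator, Tensor, upper=bool) is not implemented.". The function and argument names in this example are illustrative, not taken from the diff.)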
linear_operator/operators/permutation_linear_operator.py (168 changes: 168 additions & 0 deletions)
@@ -0,0 +1,168 @@
from typing import Optional, Tuple

import torch
from torch import Tensor

from ._linear_operator import LinearOperator


class AbstractPermutationLinearOperator(LinearOperator):
r"""Abstract base class for permutation operators.
Incorporates 1) square shape, 2) common input shape checking, and
3) the fact that permutation matrices' transposes are their inverses.
"""

def inverse(self):
return self._transpose_nonbatch()

def _solve(self, rhs: Tensor):
self._matmul_check_shape(rhs)
return self.inverse() @ rhs

    def _matmul_check_shape(self, rhs: Tensor) -> None:
        if rhs.shape[-2] != self.shape[-1]:
            raise ValueError(
                f"{rhs.shape[-2] = } incompatible with the last dimension of the "
                f"permutation operator with shape {self.shape}."
            )

def _matmul_batch_shape(self, rhs: Tensor) -> torch.Size:
return torch.broadcast_shapes(self.batch_shape, rhs.shape[:-2])


class PermutationLinearOperator(AbstractPermutationLinearOperator):
r"""LinearOperator that lazily represents a permutation matrix with O(n) memory.
Upon left-multiplication, it permutes the first non-batch dimension of a tensor.
Args:
- perm: A permutation tensor with which to permute the first non-batch
dimension through matmul. Should have integer elements. If perm
is multi-dimensional, the corresponding operator MVM broadcasts
the permutation along the batch dimensions. That is, if perm is two-
dimensional, perm[i, :] should be a permutation for every i.
        - inv_perm: Optional tensor representing the inverse of perm; computed
            via an O(n log n) sort if not given.
        - validate_args: If True, checks that perm has integer entries and that
            perm (together with inv_perm, if given) encodes a valid permutation.
"""

def __init__(
self,
perm: Tensor,
inv_perm: Optional[Tensor] = None,
validate_args: bool = True,
):
if not isinstance(perm, Tensor):
raise ValueError("perm is not a Tensor.")

if inv_perm is not None:
if perm.shape != inv_perm.shape:
raise ValueError("inv_perm does not have the same shape as perm.")

batch_indices = self._batch_indexing_helper(perm.shape[:-1])
sorted_perm = perm[batch_indices + (inv_perm,)]
else:
sorted_perm, inv_perm = perm.sort(dim=-1)

if validate_args:
if torch.is_floating_point(sorted_perm) or torch.is_complex(sorted_perm):
raise ValueError("perm does not have integer elements.")

for i in range(sorted_perm.shape[-1]):
if (sorted_perm[..., i] != i).any():
raise ValueError(
f"Invalid perm-inv_perm input, index {i} missing or not at "
f"correct index for permutation with {perm.shape = }."
)

self.perm = perm
self.inv_perm = inv_perm
super().__init__(perm, inv_perm, validate_args=validate_args)

def _matmul(self, rhs: Tensor) -> Tensor:
# input rhs is guaranteed to be at least two-dimensional due to matmul implementation
self._matmul_check_shape(rhs)

# batch broadcasting logic
batch_shape = self._matmul_batch_shape(rhs)
expanded_rhs = rhs.expand(*batch_shape, *rhs.shape[-2:])
ndim = expanded_rhs.ndim

batch_indices = self._batch_indexing_helper(batch_shape)
batch_indices = tuple(index.unsqueeze(-1) for index in batch_indices) # expanding to non-batch dimensions
perm_indices = self.perm.unsqueeze(-1)
        final_indices = torch.arange(rhs.shape[-1], device=rhs.device).view((1,) * (ndim - 1) + (-1,))
indices = batch_indices + (perm_indices, final_indices)
return expanded_rhs[indices]
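
    # Illustrative note (not part of the commit): with perm = [2, 0, 1] and rhs
    # rows [r0, r1, r2], _matmul returns rows [r2, r0, r1]; that is, row i of
    # the result is row perm[i] of rhs.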

def _batch_indexing_helper(self, batch_shape: torch.Size) -> Tuple:
"""Creates a tuple of indices with broadcastable shapes to preserve the
batch dimensions when indexing into the non-batch dimensions with `perm`.
Args:
- batch_shape: the batch shape for which to generate the broadcastable indices.
"""
return tuple(
torch.arange(n).view(
(1,) * i + (-1,) + (1,) * (len(batch_shape) - i - 1) + (1,) # adding one non-batch dimension
)
for i, n in enumerate(batch_shape)
)
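
    # Illustrative note (not part of the commit): for batch_shape = (5, 2), this
    # helper returns index tensors of shapes (5, 1, 1) and (1, 2, 1), which
    # broadcast against each other and against `perm`, so that advanced indexing
    # permutes only the row dimension while preserving the batch dimensions.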

def _size(self) -> torch.Size:
return torch.Size((*self.perm.shape, self.perm.shape[-1]))

def _transpose_nonbatch(self):
return PermutationLinearOperator(perm=self.inv_perm, inv_perm=self.perm, validate_args=False)

def to_sparse(self) -> Tensor:
"""Returns a sparse CSR tensor that represents the PermutationLinearOperator."""
# crow_indices[i] is index where values of row i begin
return torch.sparse_csr_tensor(
crow_indices=torch.arange(self.shape[-1] + 1).expand(*self.batch_shape, -1).contiguous(),
col_indices=self.perm,
values=torch.ones_like(self.perm),
)
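
    # Illustrative note (not part of the commit): for perm = [2, 0, 1], the CSR
    # tensor has crow_indices = [0, 1, 2, 3], col_indices = [2, 0, 1], and unit
    # values: exactly one 1 per row, in column perm[i] of row i.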


class TransposePermutationLinearOperator(AbstractPermutationLinearOperator):
r"""LinearOperator that represents a permutation matrix `P` with O(1) memory.
In particular, P satisfies
`P @ X.flatten(-2, -1) = X.transpose(-2, -1).flatten(-2, -1)`,
where `X` is an `m x m` matrix and P has size `n x n` where `n = m^2`.
Args:
        - m: side length of the square matrices being transposed; the permutation
            matrix that the operator represents then has size `n x n` with `n = m^2`.
"""

def __init__(self, m: int):
if m < 1:
raise ValueError(f"m = {m} has to be a positive integer.")
super().__init__(m=m)
self.n = m * m # size of implicitly represented linear operator
self.m = m # (m, m) is size of the reshaped input which is transposed
        self._dtype = type(m)  # placeholder dtype (Python's int) until `type` is called with a torch dtype

def _matmul(self, rhs: Tensor) -> Tensor:
self._matmul_check_shape(rhs)
return rhs.unflatten(dim=-2, sizes=(self.m, self.m)).transpose(-3, -2).flatten(start_dim=-3, end_dim=-2)
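
    # Illustrative note (not part of the commit): for m = 2 and
    # x = [a, b, c, d] = vec([[a, b], [c, d]]), the unflatten-transpose-flatten
    # chain returns [a, c, b, d] = vec([[a, c], [b, d]]), the vectorized transpose.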

def _size(self) -> torch.Size:
return torch.Size((self.n, self.n))

def _transpose_nonbatch(self):
return self

@property
def dtype(self):
return self._dtype

def type(self, dtype: torch.dtype) -> LinearOperator:
self._dtype = dtype
return self

    @property
    def device(self):
        return None  # device-agnostic: the operator stores no tensor data
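
A minimal usage sketch of the new operators (illustrative only, not part of the commit):

    import torch
    from linear_operator.operators import PermutationLinearOperator, TransposePermutationLinearOperator

    perm = torch.tensor([2, 0, 1])
    P = PermutationLinearOperator(perm)
    x = torch.randn(3, 4)
    y = P @ x                                # y[i] == x[perm[i]]
    assert torch.equal(y, x[perm])
    assert torch.equal(P.inverse() @ y, x)   # the transpose is the inverse

    T = TransposePermutationLinearOperator(m=2)
    X = torch.randn(2, 2)
    assert torch.equal(T @ X.flatten(), X.T.flatten())
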
test/operators/test_permutation_linear_operator.py (94 changes: 94 additions & 0 deletions)
@@ -0,0 +1,94 @@
#!/usr/bin/env python3

import unittest

import torch

from linear_operator.operators import PermutationLinearOperator, TransposePermutationLinearOperator


class TestPermutationLinearOperator(unittest.TestCase):
def test_permutation_linear_operator(self):
with self.assertRaisesRegex(ValueError, "perm is not a Tensor."):
PermutationLinearOperator([1, 3, 5])

with self.assertRaisesRegex(ValueError, "Invalid perm*"):
PermutationLinearOperator(torch.tensor([1, 3, 5]))

with self.assertRaisesRegex(ValueError, "Invalid perm*"):
PermutationLinearOperator(torch.tensor([0, 2, 1]), torch.tensor([0, 1, 2]))

n = 3
P = PermutationLinearOperator(torch.randperm(n))

b1 = 2
b2 = 5
batch_shapes = [(), (b1,), (b2, b1)]

permutations = [
torch.randperm(n),
torch.cat(tuple(torch.randperm(n).unsqueeze(0) for _ in range(b1)), dim=0),
]
operators = [PermutationLinearOperator(perm) for perm in permutations]
right_hand_sides = [torch.randn(n)] + [torch.randn(*batch_shape, n, 4) for batch_shape in batch_shapes]

for P in operators:
if torch.__version__ > "1.12":
D = P.to_dense()
S = P.to_sparse()
self.assertTrue(isinstance(S, torch.Tensor))
self.assertTrue(S.layout == torch.sparse_csr)
self.assertTrue(torch.equal(D, S.to_dense()))

for x in right_hand_sides:
batch_shape = torch.broadcast_shapes(P.batch_shape, x.shape[:-2])
expanded_x = x.expand(*batch_shape, *x.shape[-2:]).contiguous()
self.assertTrue(P._matmul_batch_shape(x) == batch_shape)
y = P @ x

            # the computed inv_perm field inverts the permutation: perm[inv_perm] == arange(n)
perm_batch_indices = P._batch_indexing_helper(P.batch_shape)
self.assertTrue((P.perm[perm_batch_indices + (P.inv_perm,)] == torch.arange(n)).all())

# application of permutation operator correctly permutes the input
batch_indices = P._batch_indexing_helper(batch_shape)
indices = batch_indices + (P.perm, slice(None))
if x.ndim == 1:
expanded_x = expanded_x.unsqueeze(-1)
y = y.unsqueeze(-1)

xp = expanded_x[indices]
self.assertTrue(torch.equal(y, xp))

# inverse of permutation operator
P_inv = torch.inverse(P)
self.assertTrue(torch.equal(P_inv @ y, expanded_x))

# transpose of permutation operator is equal to its inverse
self.assertTrue(torch.equal(P.transpose(-1, -2).perm, P_inv.perm))


class TestTransposePermutationLinearOperator(unittest.TestCase):
def test_transpose_permutation_linear_operator(self):
m = 0
msg = "m*has to be a positive integer."
with self.assertRaisesRegex(ValueError, msg):
TransposePermutationLinearOperator(m)

m = 3
P = TransposePermutationLinearOperator(m)
n = m**2
self.assertTrue(P.shape == (n, n))

batch_shapes = [(), (2,), (5, 2)]
right_hand_sides = [torch.randn(n)] + [torch.randn(*batch_shape, n, 3) for batch_shape in batch_shapes]

for x in right_hand_sides:
flat_i = -2 if x.ndim > 1 else -1
X = x.unflatten(flat_i, (m, m))
Xt = X.transpose(flat_i - 1, flat_i)
xt = Xt.flatten(start_dim=flat_i - 1, end_dim=flat_i)
y = P @ x
self.assertTrue(torch.equal(y, xt))
self.assertTrue(P is P.inverse())
self.assertTrue((P @ y - x).abs().max() == 0)
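
(To run these tests locally, a standard invocation such as `python -m unittest test.operators.test_permutation_linear_operator` should work.)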
