
[Feature Request] Block matrix support #54

Open
hughsalimbeni opened this issue Mar 26, 2023 · 3 comments
Labels
enhancement New feature or request

Comments


hughsalimbeni commented Mar 26, 2023

🚀 Feature Request

Represent [TN, TM] tensors as TxT blocks of NxM lazy tensors. Block matrices are already supported, but the efficient representation only applies when there is diagonal structure over the T dimensions.

Motivation

Here is an example that linear_operator cannot deal with:

import torch
import itertools

T, N, M = 2, 4, 3
As = [torch.rand(N, M) for _ in range(T)]
Bs = [[torch.rand(M, M) for _ in range(T)] for _ in range(T)]
Cs = [torch.rand(N, N) for _ in range(T)]
L = torch.rand(T, T)

A_bl = torch.zeros((N * T, M * T))  # BlockDiag (non-square)
B_bl = torch.zeros((M * T, M * T))  # Dense
C_bl = torch.zeros((N * T, N * T))  # BlockDiag
L_bl = torch.kron(L, torch.eye(N))  # Kronecker

for t in range(T):
    A_bl[N * t : N * (t + 1), M * t : M * (t + 1)] = As[t]
    C_bl[N * t : N * (t + 1), N * t : N * (t + 1)] = Cs[t]

for t1, t2 in itertools.product(range(T), range(T)):
    B_bl[M * t1 : M * (t1 + 1), M * t2 : M * (t2 + 1)] = Bs[t1][t2]

# Desired calculation
print("inefficient method")
print(torch.diag(L_bl @ (C_bl + A_bl @ B_bl @ A_bl.T) @ L_bl.T))

This calculation turns up in some multi-output GP models. It has a straightforward efficient implementation:

M_diag = {}
# We only need the diagonal of each block of M
for t1, t2 in itertools.product(range(T), range(T)):
    r = (As[t1].T * (Bs[t1][t2] @ As[t2].T)).sum(0)
    if t1 == t2:
        r += torch.diag(Cs[t1])
    M_diag[(t1, t2)] = r

# The rotation is applied blockwise due to the kron structure
R = {}
for t in range(T):  # we don't need the off-diag blocks
    r = 0
    for i1, i2 in itertools.product(range(T), range(T)):
        r += L[t, i1] * M_diag[(i1, i2)] * L[t, i2]
    R[t] = r

print("fast way")
print(torch.concat([R[t] for t in range(T)]))

Currently, this calculation can be implemented with linear_operator like this:

from linear_operator.operators import (
    to_linear_operator,
    IdentityLinearOperator,
    BlockDiagLinearOperator,
    BlockLinearOperator,
    MatmulLinearOperator,
    KroneckerProductLinearOperator,
)


class BlockDiagLinearOperatorNonSquare(BlockLinearOperator):
    _add_batch_dim = BlockDiagLinearOperator._add_batch_dim
    _remove_batch_dim = BlockDiagLinearOperator._remove_batch_dim
    _get_indices = BlockDiagLinearOperator._get_indices
    _size = BlockDiagLinearOperator._size
    num_blocks = BlockDiagLinearOperator.num_blocks

    def __init__(self, base_linear_op, block_dim=-3):
        super().__init__(base_linear_op, block_dim)

A_lo = BlockDiagLinearOperatorNonSquare(torch.stack(As, 0))
B_lo = to_linear_operator(B_bl)
C_lo = BlockDiagLinearOperator(to_linear_operator(torch.stack(Cs, 0)))
L_lo = KroneckerProductLinearOperator(L, IdentityLinearOperator(N))

M = MatmulLinearOperator(A_lo, (MatmulLinearOperator(B_lo, A_lo.T)))

print("using linear operator, with to_dense()")
print(
    MatmulLinearOperator(L_lo, MatmulLinearOperator(C_lo + M, L_lo.T))
    .to_dense()
    .diagonal()
)

Removing the to_dense() gives an error, however.

Pitch

Add a block linear operator class that keeps track of the [T, T] block structure, represented as T^2 lazy tensors of the same shape. Implement matrix multiplication between block matrices as the appropriate linear operators on the blocks.
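
For concreteness, here is a minimal sketch of what such a class could look like, built directly on dense torch tensors rather than the LinearOperator machinery. The name DenseBlockOperator and its methods are hypothetical and only illustrate the blockwise matmul idea, not a full implementation.

import torch


class DenseBlockOperator:
    """Hypothetical: a T1 x T2 grid of equally shaped blocks, multiplied blockwise."""

    def __init__(self, blocks):
        # blocks[i][j] is an (N, M) tensor; lazy tensors would slot in the same way.
        self.blocks = blocks
        self.T1, self.T2 = len(blocks), len(blocks[0])

    def matmul(self, other):
        # (T1 x T2 grid) @ (T2 x T3 grid) -> (T1 x T3 grid); each output block is a
        # sum of T2 block products, which is where lazy operators would pay off.
        assert self.T2 == other.T1
        out = [
            [
                sum(self.blocks[i][k] @ other.blocks[k][j] for k in range(self.T2))
                for j in range(other.T2)
            ]
            for i in range(self.T1)
        ]
        return DenseBlockOperator(out)

    def to_dense(self):
        return torch.cat([torch.cat(row, dim=-1) for row in self.blocks], dim=-2)


# Sanity check: the blockwise product matches the dense product.
blocks_a = [[torch.rand(4, 3) for _ in range(2)] for _ in range(2)]
blocks_b = [[torch.rand(3, 3) for _ in range(2)] for _ in range(2)]
A, B = DenseBlockOperator(blocks_a), DenseBlockOperator(blocks_b)
print(torch.allclose(A.matmul(B).to_dense(), A.to_dense() @ B.to_dense()))  # True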

As a workaround, I have written manual implementations of specific cases, such as the one above.

I'm willing to work on a PR for this.

Additional context

None

hughsalimbeni added the enhancement label Mar 26, 2023
@Balandat
Collaborator

Thanks for the suggestion, @hughsalimbeni! @gpleiss, @jacobrgardner and I have talked in the past about expanding linear_operator beyond the current focus on square (really, symmetric PSD) matrices.

The BlockDiagLinearOperatorNonSquare extending BlockDiagLinearOperator seems like a nifty way of realizing this without a ton of refactoring, but ideally we'd rethink the inheritance structure so that we'd have something general like

LinearOperator -> BlockLinearOperator -> BlockDiagLinearOperator -> DiagLinearOperator

where operators are not assumed to be square (there could simply be an is_square property computed from the trailing two dimensions) or symmetric or positive definite (those could also be properties).
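
For illustration, a minimal sketch of that idea, assuming shapes follow the usual torch convention of batch dimensions followed by the two matrix dimensions; the class and property below are illustrative, not existing linear_operator API.

class ShapeAwareOperatorSketch:
    """Illustrative only: shows how is_square could be derived from the shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)  # e.g. (batch..., rows, cols)

    @property
    def is_square(self) -> bool:
        # Square iff the trailing two (matrix) dimensions agree.
        return self.shape[-2] == self.shape[-1]


# A (T*N, T*M) block operator with N != M would report is_square == False.
print(ShapeAwareOperatorSketch((8, 6)).is_square)     # False
print(ShapeAwareOperatorSketch((2, 6, 6)).is_square)  # True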

This would of course be a major redesign of the whole library, and so is presumably out of scope for what you're trying to achieve here. But adding your suggestion could be a step on the way to a more general setup, and could inform / be absorbed into a larger rewrite down the road. So I'm happy to help review a PR for this.

@Balandat
Collaborator

cc @SebastianAment

@corwinjoy
Contributor

Looks like a great addition. The key question is which functions need to be implemented to make this a reality. From the library description, we must implement (see the sketch after the lists below):
_matmul
_transpose_nonbatch

I'm not sure what else makes sense. It seems like we might also want:
_diagonal
_root_decomposition?
_root_inv_decomposition?
_solve?
inv_quad_logdet?
_svd?
_symeig?
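
For reference, a rough sketch of what the two required methods could look like for a T1 x T2 grid of dense blocks. The class below is standalone and illustrative, not the actual linear_operator API; only the method names _matmul and _transpose_nonbatch are taken from the interface, and all other names and shapes are assumptions.

import torch


class BlockOperatorSketch:
    """Illustrative T1 x T2 block operator; blocks[i][j] is an (N, M) tensor."""

    def __init__(self, blocks):
        self.blocks = blocks
        self.T1, self.T2 = len(blocks), len(blocks[0])
        self.N, self.M = blocks[0][0].shape[-2:]

    def _matmul(self, rhs):
        # rhs has shape (T2 * M, K); split it into T2 row-chunks of size M
        # and accumulate one block-row of the product at a time.
        rhs_rows = rhs.split(self.M, dim=-2)
        out_rows = [
            sum(self.blocks[i][j] @ rhs_rows[j] for j in range(self.T2))
            for i in range(self.T1)
        ]
        return torch.cat(out_rows, dim=-2)

    def _transpose_nonbatch(self):
        # Transpose the grid of blocks and each block itself.
        return BlockOperatorSketch(
            [[self.blocks[i][j].mT for i in range(self.T1)] for j in range(self.T2)]
        )


# Quick check against the dense equivalent.
blocks = [[torch.rand(4, 3) for _ in range(2)] for _ in range(2)]
op = BlockOperatorSketch(blocks)
dense = torch.cat([torch.cat(row, dim=-1) for row in blocks], dim=-2)
rhs = torch.rand(2 * 3, 5)
print(torch.allclose(op._matmul(rhs), dense @ rhs))  # True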
