[Bug] SumBatchLinearOperator fails for high-order tensor #100

Open
lmao14 opened this issue Sep 9, 2024 · 0 comments
Labels
bug Something isn't working

Comments

lmao14 commented Sep 9, 2024

🐛 Bug

To reproduce

Code snippet to reproduce

import torch
import gpytorch
import linear_operator

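# Two-dimensional kernel batch shape: no error, but the lazy sum has the wrong shape
kern = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3])),
    batch_shape=torch.Size([4, 3]),
)
X = torch.randn([2, 5])  # 2 data points with 5 features, no batch dimensions
kxx = kern(X)
print(kxx.shape)                    # lazy kernel matrix
print(kxx.to_dense().sum(0).shape)  # densify first, then sum over dim 0
print(kxx.sum(0).to_dense().shape)  # lazy sum over dim 0, then densify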

Output (the lazy sum on the last line should be torch.Size([3, 2, 2]), matching the dense sum):

torch.Size([4, 3, 2, 2])
torch.Size([3, 2, 2])
torch.Size([4, 5, 5])
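Summing over a batch dimension should only drop that dimension from the shape, which is what the dense path (kxx.to_dense().sum(0)) does. The pure-Python sketch below spells out that expected shape rule. `expected_batch_sum_shape` is a hypothetical helper written for illustration, not part of linear_operator, and it assumes the last two dimensions are the matrix dimensions.

```python
def expected_batch_sum_shape(shape, dim):
    """Shape expected from summing a batched matrix of `shape` over batch dim `dim`.

    Hypothetical helper (not part of linear_operator). The last two entries of
    `shape` are assumed to be the matrix dimensions, so only the leading batch
    dimensions may be summed over.
    """
    n_batch = len(shape) - 2
    if dim < 0:
        dim += len(shape)
    if not 0 <= dim < n_batch:
        raise ValueError(f"dim {dim} is not a batch dimension of shape {tuple(shape)}")
    # Summing removes the chosen batch dimension and leaves the matrix untouched.
    return tuple(s for i, s in enumerate(shape) if i != dim)


# First example: batch shape (4, 3), two data points -> 2 x 2 kernel matrices.
print(expected_batch_sum_shape((4, 3, 2, 2), 0))     # (3, 2, 2)
# Second example: batch shape (5, 4, 3), same two data points.
print(expected_batch_sum_shape((5, 4, 3, 2, 2), 0))  # (4, 3, 2, 2)
```

Under this rule the lazy result torch.Size([4, 5, 5]) is wrong in both the batch size and the matrix size: the matrices should stay 2 x 2 because X holds two data points.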

With a three-dimensional batch shape, the lazy sum raises an error instead:

kern = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3])),
    batch_shape=torch.Size([5, 4, 3]),
)
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.sum(0).to_dense().shape)  # raises RuntimeError

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[65], line 5
      3 X = torch.randn([2, 5])
      4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:2517, in LinearOperator.sum(self, dim)
   2515 # Otherwise: it's a batch dimension
   2516 elif dim < self.dim():
-> 2517     return self._sum_batch(dim)
   2518 else:
   2519     raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:861, in LinearOperator._sum_batch(self, dim)
    850 """
    851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
    852 
   (...)
    857 :param dim: The (positive valued) dimension to sum
    858 """
    859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/block_linear_operator.py:50, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
     48 if block_dim != -3:
     49     positive_block_dim = base_linear_op.dim() + block_dim
---> 50     base_linear_op = base_linear_op._permute_batch(
     51         *range(positive_block_dim),
     52         *range(positive_block_dim + 1, base_linear_op.dim() - 2),
     53         positive_block_dim,
     54     )
     55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
     56 self.base_linear_op = base_linear_op

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:248, in LinearOperator._permute_batch(self, *dims)
    246 if torch.is_tensor(component):
    247     extra_dims = range(len(dims), component.dim())
--> 248     components.append(component.permute(*dims, *extra_dims))
    249 elif isinstance(component, LinearOperator):
    250     components.append(component._permute_batch(*dims))

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3
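The last two frames can be replayed in plain Python to see where len(dims) = 3 comes from. The sketch below copies the index arithmetic shown in BlockLinearOperator.__init__ (trace lines 48-54) and LinearOperator._permute_batch (trace lines 246-248). It assumes that block_dim has already been converted to a negative index earlier in __init__; those lines are not shown in the trace, but the assumption is consistent with the reported len(dims) = 3. The trace also does not say which 2-D tensor component fails; the unbatched input X of shape [2, 5] is one plausible candidate. The function names are hypothetical.

```python
def block_permutation(n_dims, block_dim):
    """Batch permutation that moves `block_dim` to the last batch position.

    Mirrors the arithmetic from BlockLinearOperator.__init__ in the traceback.
    Assumption: a non-negative block_dim is first made negative.
    """
    if block_dim >= 0:
        block_dim -= n_dims  # assumed normalization (not shown in the trace)
    if block_dim == -3:
        return None  # already the last batch dimension; no permutation needed
    positive_block_dim = n_dims + block_dim
    return (
        *range(positive_block_dim),
        *range(positive_block_dim + 1, n_dims - 2),
        positive_block_dim,
    )


def component_ordering(dims, component_ndim):
    """Ordering that _permute_batch passes to Tensor.permute for one tensor component."""
    extra_dims = range(len(dims), component_ndim)
    return (*dims, *extra_dims)


# First example: 4-D operator (batch 4 x 3), summed over dim 0.
dims4 = block_permutation(4, 0)
print(dims4)                         # (1, 0)
print(component_ordering(dims4, 2))  # (1, 0): valid for a 2-D tensor, but transposes it

# Second example: 5-D operator (batch 5 x 4 x 3), summed over dim 0.
dims5 = block_permutation(5, 0)
print(dims5)                         # (1, 2, 0)
# A 2-D component receives a length-3 ordering, which permute rejects.
print(len(component_ordering(dims5, 2)))  # 3
# A fully batched 5-D component would be permuted without error.
print(component_ordering(dims5, 5))  # (1, 2, 0, 3, 4)
```

For the 4-D operator the same arithmetic produces the valid ordering (1, 0), so a 2-D unbatched component would be silently transposed rather than rejected. If that component is X, two 5-feature points would become five 2-feature points, which would be consistent with the 5 x 5 matrices in the first output; this is a hypothesis the trace does not confirm.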

System information

  • LinearOperator Version 0.5.3
  • PyTorch Version 2.0.1
lmao14 added the bug (Something isn't working) label on Sep 9, 2024