# 🐛 Bug

Calling `LinearOperator.sum(dim)` over a batch dimension of a lazily evaluated kernel matrix returns a result of the wrong shape when there are two batch dimensions, and raises a `RuntimeError` when there are three.

## To reproduce

** Code snippet to reproduce **
```python
import torch
import gpytorch
import linear_operator

kern = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3])),
    batch_shape=torch.Size([4, 3]),
)
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.shape)
print(kxx.to_dense().sum(0).shape)  # dense path
print(kxx.sum(0).to_dense().shape)  # lazy path
```
```
torch.Size([4, 3, 2, 2])
torch.Size([3, 2, 2])
torch.Size([4, 5, 5])
```

Summing after `to_dense()` gives the expected `torch.Size([3, 2, 2])`; summing lazily with `kxx.sum(0)` returns `torch.Size([4, 5, 5])` instead.
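For comparison, a workaround that avoids the lazy `sum` entirely (a sketch, assuming densifying the kernel matrix is affordable at this size; `to_linear_operator` is the public wrapper exported by `linear_operator`):

```python
# Sketch of a workaround, not a fix: sum in dense space, then re-wrap the
# result as a LinearOperator if lazy semantics are still needed downstream.
from linear_operator import to_linear_operator

summed = to_linear_operator(kxx.to_dense().sum(0))
print(summed.shape)  # torch.Size([3, 2, 2]), matching the dense reference above
```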
With three batch dimensions, the lazy sum fails outright:

```python
kern = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3])),
    batch_shape=torch.Size([5, 4, 3]),
)
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.sum(0).to_dense().shape)  # raises RuntimeError, see trace below
```
** Stack trace/error message **
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[65], line 5
      3 X = torch.randn([2, 5])
      4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:2517, in LinearOperator.sum(self, dim)
   2515 # Otherwise: it's a batch dimension
   2516 elif dim < self.dim():
-> 2517     return self._sum_batch(dim)
   2518 else:
   2519     raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:861, in LinearOperator._sum_batch(self, dim)
    850 """
    851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
    852 (...)
    857 :param dim: The (positive valued) dimension to sum
    858 """
    859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43 else:
     44     new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43 else:
     44     new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/block_linear_operator.py:50, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
     48 if block_dim != -3:
     49     positive_block_dim = base_linear_op.dim() + block_dim
---> 50     base_linear_op = base_linear_op._permute_batch(
     51         *range(positive_block_dim),
     52         *range(positive_block_dim + 1, base_linear_op.dim() - 2),
     53         positive_block_dim,
     54     )
     55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
     56 self.base_linear_op = base_linear_op

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:248, in LinearOperator._permute_batch(self, *dims)
    246 if torch.is_tensor(component):
    247     extra_dims = range(len(dims), component.dim())
--> 248     components.append(component.permute(*dims, *extra_dims))
    249 elif isinstance(component, LinearOperator):
    250     components.append(component._permute_batch(*dims))

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3
```
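From the trace, this looks like a sign-convention mismatch: `LinearOperator._sum_batch` receives a positive `dim` (per its own docstring) and forwards it unchanged as `SumBatchLinearOperator(self, block_dim=dim)`, but `BlockLinearOperator.__init__` normalizes `block_dim` via `base_linear_op.dim() + block_dim`, which only makes sense for a negative value. A minimal sketch of a possible fix under that assumption (unverified against upstream):

```python
# Hypothetical patch to LinearOperator._sum_batch: a sketch, not a confirmed
# upstream fix. Converting the positive dim to its negative equivalent makes
# BlockLinearOperator's `base_linear_op.dim() + block_dim` normalization land
# on the intended batch dimension.
def _sum_batch(self, dim):
    from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator

    # `dim` is positive here (see the docstring in the trace above); pass the
    # equivalent negative index instead of the raw positive one.
    return SumBatchLinearOperator(self, block_dim=dim - self.dim())
```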
## System information

**Please complete the following information:**