🐛 Bug
I have implemented a KeOps periodic kernel in cornellius-gp/gpytorch#2296, but the raw_lengthscale parameter does not have gradients computed (see cornellius-gp/gpytorch#2296 (comment)). I have tracked the issue down to the matrix multiplication implemented in LinearOperator.matmul (see Expected Behavior below). This matrix multiplication is called when doing LinearOperator.sum, as in the example below.
To reproduce
** Code snippet to reproduce **
import torch
import gpytorch
from gpytorch.kernels.keops import PeriodicKernel as KeOpsPeriodicKernel  # implementation from pull request 2296

torch.manual_seed(7)

M, N, D = 1000, 2000, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()
k = KeOpsPeriodicKernel(ard_num_dims=3).double()
k.lengthscale = torch.tensor(1.0).double()
k.period_length = torch.tensor(1.0).double()

# context manager used so that type(covar) is KeOpsLinearOperator, not LazyEvaluatedKernelTensor
with gpytorch.settings.lazily_evaluate_kernels(False):
    covar = k(x, y)
print(type(covar))

# Calls `LinearOperator.sum`, which subsequently calls `LinearOperator.matmul`.
# `LinearOperator.matmul` uses a custom torch.autograd.Function for the matrix multiplication.
res2 = covar.sum(dim=1)  # res2 is a torch.Tensor here
res2 = res2.sum()
print(res2)

g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
print(g_x)
** Stack trace/error message **
<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
Traceback (most recent call last):
File "/home/julian/Desktop/test/keops_periodic_low_level/issue_keops_linear_operator.py", line 23, in <module>
g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
File "/home/julian/.venv/ichor/lib/python3.10/site-packages/torch/autograd/__init__.py", line 300, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
where k.raw_lengthscale is causing the issue.
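For context, this RuntimeError is raised whenever torch.autograd.grad is asked for a gradient that the backward pass never produced. The self-contained sketch below (the ForgetfulMatmul class is hypothetical and unrelated to linear_operator's actual Matmul implementation) triggers the same message with a custom torch.autograd.Function whose backward returns no gradient for one of its inputs, which looks like the same failure mode:

import torch

class ForgetfulMatmul(torch.autograd.Function):
    # Toy matmul whose backward never returns a gradient for `a`, so `a`
    # ends up with an undefined gradient even though it requires grad.
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output):
        (a,) = ctx.saved_tensors
        grad_b = a.transpose(-1, -2) @ grad_output
        return None, grad_b  # no gradient for `a`

a = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)
res = ForgetfulMatmul.apply(a, b).sum()

try:
    torch.autograd.grad(res, [a, b])
except RuntimeError as err:
    print(err)  # "One of the differentiated Tensors appears to not have been used in the graph. ..."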
Expected Behavior
Compute the gradients for both k.raw_lengthscale and k.raw_period_length. The exact place where the issue occurs is linear_operator/linear_operator/operators/_linear_operator.py, line 2366 (at commit 92f7e33), which calls LinearOperator.matmul and loses track of the gradients. I am summing across the columns in the example, but the same issue occurs when summing across the rows. Adding the following check
# Case: summing across columns
if dim == (self.dim() - 1):
    ones = torch.ones(self.size(-1), 1, dtype=self.dtype, device=self.device)
    from .keops_linear_operator import KeOpsLinearOperator
    if isinstance(self, KeOpsLinearOperator):
        return self.covar_mat.sum(dim=1)
    return (self @ ones).squeeze(-1)
gives gradients for both raw_lengthscale and raw_period_length, as the custom Matmul is never called. This is probably not the best solution; perhaps the Matmul forward/backward methods can be changed so that the gradients are computed correctly? As a check, the same numbers are returned if the normal periodic kernel is used.
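The check against the ordinary periodic kernel can be written as something like the sketch below (k_ref, covar_ref, res_ref, and g_ref are illustrative names; x and y are reused from the snippet in the To reproduce section, and the dense gpytorch.kernels.PeriodicKernel is assumed to accept the same hyperparameters):

from gpytorch.kernels import PeriodicKernel

k_ref = PeriodicKernel(ard_num_dims=3).double()
k_ref.lengthscale = torch.tensor(1.0).double()
k_ref.period_length = torch.tensor(1.0).double()

with gpytorch.settings.lazily_evaluate_kernels(False):
    covar_ref = k_ref(x, y)  # a dense LinearOperator rather than a KeOpsLinearOperator

res_ref = covar_ref.sum(dim=1).sum()
print(res_ref)

# Both gradients are returned without error here, and they should match the
# values obtained from the KeOps kernel once the check above is in place.
g_ref = torch.autograd.grad(res_ref, [k_ref.raw_lengthscale, k_ref.raw_period_length])
print(g_ref)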
System information
linear_operator version: 0.3.0
torch version: 1.13.1+cu117