Hello,
I would really appreciate any advice, as this is my first time trying to integrate Pyro into an existing model architecture (i.e., I am probably missing something really simple...). Right now I am able to run NUTS and obtain a posterior over the model hyperparameters, but I can't evaluate the model afterwards: there is a shape mismatch when the Kronecker product involving the posterior of the B matrix is computed.
Maybe it has something to do with how the Kronecker product is evaluated along the batch (sample) dimension?
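(To illustrate what I mean, here is a tiny sketch, separate from the model below, of the batch-shape requirement I suspect is being violated; the DenseLinearOperator wrappers and the shapes are made up purely for illustration.)

import torch
from linear_operator.operators import DenseLinearOperator, KroneckerProductLinearOperator

# With matching batch shapes (e.g. a leading sample dimension of 3), the Kronecker
# product works and keeps the batch dimension:
data_covar = DenseLinearOperator(torch.randn(3, 100, 100))  # k(x, x') per sample
task_covar = DenseLinearOperator(torch.randn(3, 2, 2))      # B per sample
print(KroneckerProductLinearOperator(data_covar, task_covar).shape)  # torch.Size([3, 200, 200])

# With batch shapes that cannot broadcast, the constructor raises the same
# "Batch shapes ... are incompatible for a Kronecker product" RuntimeError as in the
# traceback below:
# KroneckerProductLinearOperator(
#     DenseLinearOperator(torch.randn(3, 100, 100)),
#     DenseLinearOperator(torch.randn(4, 2, 2)),
# )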
Minimal working example, adapted from the NUTS integration example in the documentation:
import math
import torch
import gpytorch
import pyro
from pyro.infer.mcmc import NUTS, MCMC, HMC
from matplotlib import pyplot as plt

N = 100
train_x = torch.linspace(0, 1, N)[..., None]
train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(N, 1) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(N, 1) * 0.2,
], -1).squeeze()
class ExactMultitaskGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, latent, tasks):
        super(ExactMultitaskGP, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=tasks
        )
        latent_kernels = []
        for i in range(latent):
            latent_kernels.append(gpytorch.kernels.RBFKernel())
        self.covar_module = gpytorch.kernels.LCMKernel(
            latent_kernels, num_tasks=tasks, rank=tasks
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
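
# (Note: as far as I understand, each component of the LCMKernel above is a MultitaskKernel
# that computes the Kronecker product of the data covariance k_q(x, x') with a task
# covariance matrix B_q, i.e. K((x, i), (x', j)) = sum_q k_q(x, x') * B_q[i, j].
# The "B matrix" mentioned above is this task covariance.)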
import os
smoke_test = ('CI' in os.environ)
num_samples = 2 if smoke_test else 100
warmup_steps = 2 if smoke_test else 0

from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior

# Build the multitask likelihood and model, then register UniformPrior priors
# on the mean constants, lengthscales, task covariance factors, and noise.
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=train_y.shape[1])
model = ExactMultitaskGP(train_x, train_y, likelihood, train_y.shape[1], train_y.shape[1])
for i in range(train_y.shape[1]):
    model.mean_module.base_means[i].register_prior(f"mean_prior_{i}", UniformPrior(-5, 5), "constant")
    model.covar_module.covar_module_list[i].data_covar_module.register_prior(
        f"lengthscale_prior_{i}", UniformPrior(1e-4, 10), "lengthscale"
    )
    model.covar_module.covar_module_list[i].task_covar_module.register_prior(
        f"covar_factor_prior_{i}", UniformPrior(1e-4, 10), "covar_factor"
    )
model.likelihood.register_prior("noise_prior", UniformPrior(1e-4, 10), "noise")
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

def pyro_model(x, y):
    with gpytorch.settings.fast_computations(False, False, False):
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        pyro.sample("obs", output, obs=y)
    return y
nuts_kernel = NUTS(pyro_model)
mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
mcmc_run.run(train_x, train_y)
model.pyro_load_from_samples(mcmc_run.get_samples())
model.eval()
expanded_train_x = train_x.unsqueeze(0).repeat(num_samples, 1, 1)
output = model(expanded_train_x)
Traceback:
Traceback (most recent call last):
  Cell In[12], line 5
    output = model(expanded_train_x)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\models\exact_gp.py:294 in __call__
    self.prediction_strategy = prediction_strategy(
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\models\exact_prediction_strategies.py:37 in prediction_strategy
    return cls(train_inputs, train_prior_dist, train_labels, likelihood)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\kernels\kernel.py:445 in prediction_strategy
    return exact_prediction_strategies.DefaultPredictionStrategy(
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\models\exact_prediction_strategies.py:63 in __init__
    mvn = self.likelihood(train_prior_dist, train_inputs)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\likelihoods\likelihood.py:367 in __call__
    return self.marginal(input, *args, **kwargs) # pyre-ignore[6]
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\likelihoods\multitask_gaussian_likelihood.py:299 in marginal
    return super().marginal(function_dist, *args, **kwargs)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\likelihoods\multitask_gaussian_likelihood.py:107 in marginal
    covar = covar.evaluate_kernel()
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\utils\memoize.py:59 in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\lazy\lazy_evaluated_kernel_tensor.py:25 in wrapped
    output = method(self, *args, **kwargs)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\lazy\lazy_evaluated_kernel_tensor.py:355 in evaluate_kernel
    res = self.kernel(
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\kernels\kernel.py:530 in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\module.py:31 in __call__
    outputs = self.forward(*inputs, **kwargs)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\kernels\lcm_kernel.py:57 in forward
    res = self.covar_module_list[0].forward(x1, x2, **params)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\kernels\multitask_kernel.py:53 in forward
    res = KroneckerProductLinearOperator(covar_x, covar_i)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\gpytorch\lazy\lazy_tensor.py:46 in __init__
    return __orig_init__(self, *args, **new_kwargs)
  File ~\anaconda3\envs\pytorch\Lib\site-packages\linear_operator\operators\kronecker_product_linear_operator.py:81 in __init__
    raise RuntimeError(
RuntimeError: Batch shapes of LinearOperators ((100, 100, 100, 100), (100, 10000, 2, 2)) are incompatible for a Kronecker product.
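For reference, here is a small diagnostic sketch (not part of the example above) that prints the shape of each sampled hyperparameter after mcmc_run.run(...) finishes; I would expect every tensor to have num_samples as its leading dimension:

# Diagnostic only: inspect the posterior samples that pyro_load_from_samples() consumes.
for name, samples in mcmc_run.get_samples().items():
    print(name, tuple(samples.shape))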