diff --git a/model.py b/model.py index bc3e350..8a6bbf9 100644 --- a/model.py +++ b/model.py @@ -127,7 +127,7 @@ def forward(self, x, skpCn): """ # Bilinear interpolation with scaling 2. - x = F.interpolate(x, scale_factor=2, mode='bilinear') + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) # Convolution + Leaky ReLU x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) # Convolution + Leaky ReLU on (`x`, `skpCn`) @@ -281,7 +281,7 @@ def forward(self, img, flow): # stacking X and Y grid = torch.stack((x,y), dim=3) # Sample pixels using bilinear interpolation. - imgOut = torch.nn.functional.grid_sample(img, grid) + imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True) return imgOut @@ -358,4 +358,4 @@ def getWarpCoeff (indices, device): ind = indices.detach().numpy() C0 = 1 - t[ind] C1 = t[ind] - return torch.Tensor(C0)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C1)[None, None, None, :].permute(3, 0, 1, 2).to(device) \ No newline at end of file + return torch.Tensor(C0)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C1)[None, None, None, :].permute(3, 0, 1, 2).to(device)