diff --git a/train.py b/train.py index 5744dba..2d132a4 100644 --- a/train.py +++ b/train.py @@ -148,7 +148,7 @@ def validate(): F_t_0_f = intrpOut[:, :2, :, :] + F_t_0 F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1 - V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :]) + V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :]) V_t_1 = 1 - V_t_0 g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f) @@ -218,9 +218,6 @@ def validate(): valPSNR.append([]) iLoss = 0 - # Increment scheduler count - scheduler.step() - for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0): ## Getting the input and the target from the training set @@ -255,7 +252,7 @@ def validate(): # Extract optical flow residuals and visibility maps F_t_0_f = intrpOut[:, :2, :, :] + F_t_0 F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1 - V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :]) + V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :]) V_t_1 = 1 - V_t_0 # Get intermediate frames from the intermediate flows @@ -316,6 +313,9 @@ def validate(): iLoss = 0 start = time.time() + # Increment scheduler count + scheduler.step() + # Create checkpoint after every `args.checkpoint_epoch` epochs if ((epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1): dict1 = {