-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
148 lines (124 loc) · 4.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torchvision
import matplotlib.pyplot as plt
import torch
from vae import VAE
def calculate_kl_loss(means, stds):
    """Per-sample KL divergence between N(means, exp(stds)**2) and N(0, I).

    `stds` is interpreted as log-standard-deviations, so `2 * stds` is the
    log-variance. NOTE(review): the conventional closed-form KL carries a
    factor of 0.5; this implementation omits it, and the training loop's
    `beta_term = 0.5` appears to absorb it — confirm before reusing elsewhere.

    Returns a 1-D tensor with one KL value per row of `means` / `stds`.
    """
    log_var = 2.0 * stds
    inner = 1.0 + log_var - means.pow(2) - log_var.exp()
    return -torch.sum(inner, dim=1)
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
"""
# Define the path to the Tiny ImageNet dataset on your machine
dataset_path = os.path.expanduser(
os.path.join(
'~',
'Desktop',
'tiny-imagenet-200'
)
)
# Define transformations for the images
transform = transforms.Compose([
transforms.Resize((64, 64)), # Resize images to a consistent size
transforms.ToTensor(),
])
# Use ImageFolder to load the dataset
tinyimagenet_train_dataset = ImageFolder(root=dataset_path + '/train', transform=transform)
tinyimagenet_test_dataset = ImageFolder(root=dataset_path + '/test', transform=transform)
# Create a DataLoader for the dataset
batch_size = 256
train_loader = DataLoader(tinyimagenet_train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(tinyimagenet_test_dataset, batch_size=batch_size, shuffle=False)
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# use torchvision for mnist
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
])
train_dataset = torchvision.datasets.MNIST(
root='~/Desktop/mnist',
train=True,
transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
root='~/Desktop/mnist',
train=False,
transform=transform,
)
batch_size = 256
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
)
# --- Model and optimisation setup ---
nr_epochs = 10
# Instantiate the model
input_channels = 1 # MNIST is grayscale, so a single input channel
latent_size = 20 # Dimensionality of the VAE latent space
nr_classes = 10 # One class per MNIST digit
model = VAE(input_channels, latent_size, nr_classes)
model.train()
# AdamW optimizer (Adam with decoupled weight decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# Cosine learning-rate schedule; T_max counts total batches, which suggests
# the scheduler is meant to be stepped once per batch, not per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader) * nr_epochs)
mean_shape = None # NOTE(review): never read anywhere in this file -- candidate for removal
model.to(device)
# --- Training (or loading a previously trained model) ---
# Set load=True to skip training and restore weights from 'model.pth'.
load = False
if not load:
    for epoch in range(nr_epochs):
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            # Forward pass: posterior parameters (means, stds) and reconstruction x.
            means, stds, x = model(images, labels)
            # keeping the abs if we switch to L1 loss
            reconstruction_loss = torch.abs(images - x)
            reconstruction_loss = reconstruction_loss.pow(2)
            # Sum squared error over channel/height/width -> one value per sample.
            reconstruction_loss = torch.sum(reconstruction_loss, dim=(1, 2, 3))
            # Flatten posterior parameters to (batch, latent) before the KL term.
            kl_loss = calculate_kl_loss(means.view(x.size(0), -1), stds.view(x.size(0), -1))
            # beta_term = 4 * latent_size / (images.size(2) * images.size(3))
            beta_term = 0.5
            loss = torch.mean(reconstruction_loss + beta_term * kl_loss)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f'Epoch: {epoch + 1}, Loss: {loss.item():.3f}, Reconstruction: {torch.mean(reconstruction_loss).item():.3f}, KL: {torch.mean(kl_loss).item():.3f}')
            # Stepped per batch, matching T_max = len(train_loader) * nr_epochs
            # above -- TODO confirm against the original repo's indentation.
            scheduler.step()
        epoch_loss /= len(train_loader)
        print(f'Epoch: {epoch + 1}, Loss: {epoch_loss:.3f}')
    # save model
    torch.save(model.state_dict(), 'model.pth')
else:
    model.load_state_dict(torch.load('model.pth'))
# --- Sampling from the prior and visualising generated digits ---
model.eval()
with torch.no_grad():
    # Draw 64 latent vectors from the standard-normal prior.
    initial_z = torch.randn(64, latent_size)
    initial_z = initial_z.to(device)
    # Condition every sample on class 2 (presumably the digit "2").
    class_info = torch.full((64,), 2)
    class_info = class_info.to(device)
    y = model.class_embedding(class_info)
    # Class conditioning is injected by adding the class embedding to the latent code.
    initial_z = initial_z + y
    # Manually replay the decoder half of the VAE.
    # NOTE(review): this mirrors VAE internals (dec_full_layer -> decoder_map
    # -> relu -> reshape -> decoder); keep in sync with vae.py if it changes.
    x = model.dec_full_layer(initial_z)
    x = model.decoder_map(x)
    x = model.relu_activation(x)
    x = x.view(-1, *model.decoder_shape)
    x = model.decoder(x)
    x = x.cpu().detach()
# show images with pytorch
# reverse normalization
#x = x * 0.3081 + 0.1307
# reverse normalization
grid_img = torchvision.utils.make_grid(x, nrow=8)
plt.figure(figsize=(20, 20))
plt.imshow(grid_img.permute(1, 2, 0))
plt.tight_layout()
plt.show()