-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
75 lines (65 loc) · 2.45 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as T
from torchvision import datasets as dset
# --- MNIST preprocessing -----------------------------------------------------
# Both splits only convert the input image to a tensor; no augmentation.
# The two pipelines are kept as separate objects so they can diverge later.
MNIST_TRN_TRANSFORM = T.Compose([T.ToTensor()])
MNIST_TST_TRANSFORM = T.Compose([T.ToTensor()])

# --- CIFAR-10 preprocessing ---------------------------------------------------
# Training takes a random 28-pixel crop (light augmentation); evaluation takes
# the deterministic centre crop of the same size.
# NOTE(review): get_cifar10_dataset in this module builds its own pipelines
# and does not reference these two constants.
CIFAR10_TRN_TRANSFORM = T.Compose([T.RandomCrop(28), T.ToTensor()])
CIFAR10_TST_TRANSFORM = T.Compose([T.CenterCrop(28), T.ToTensor()])
# Human-readable CIFAR-10 class names; the position in the tuple is the
# integer label (0 = 'airplane', ..., 9 = 'truck').
CIFAR10_CLASSES = (
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
def get_mnist_dataset(trn_size=60000, tst_size=10000):
    """Load MNIST and keep only the leading samples of each split.

    The dataset is downloaded into ``./datasets`` when it is not already
    present there.

    Args:
        trn_size: Number of leading training samples to keep. Values larger
            than the split are harmless because slicing clamps to the
            available length.
        tst_size: Number of leading test samples to keep (same clamping).

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.MNIST``
        objects using ``MNIST_TRN_TRANSFORM`` and ``MNIST_TST_TRANSFORM``
        respectively.
    """
    trainset = dset.MNIST(root='./datasets', train=True,
                          download=True, transform=MNIST_TRN_TRANSFORM)
    testset = dset.MNIST(root='./datasets', train=False,
                         download=True, transform=MNIST_TST_TRANSFORM)

    for dataset, split, size in ((trainset, 'train', trn_size),
                                 (testset, 'test', tst_size)):
        # torchvision >= 0.2.2 stores samples in ``data`` / ``targets``; the
        # old ``train_data`` / ``test_labels`` names survive only as
        # read-only deprecated properties, so assigning to them raises
        # AttributeError. Fall back to the old names on legacy releases.
        if hasattr(dataset, 'targets'):
            data_attr, label_attr = 'data', 'targets'
        else:
            data_attr, label_attr = split + '_data', split + '_labels'
        setattr(dataset, data_attr, getattr(dataset, data_attr)[:size])
        setattr(dataset, label_attr, getattr(dataset, label_attr)[:size])

    return trainset, testset
def get_cifar10_dataset(trn_size=60000, tst_size=10000):
    """Load CIFAR-10 and keep only the leading samples of each split.

    The dataset is downloaded into ``./datasets`` when it is not already
    present there. Images are resized/cropped to 224 pixels and normalised
    with mean 0.5 and std 0.5 per channel; the training pipeline additionally
    applies a random resized crop and a random horizontal flip.

    Args:
        trn_size: Number of leading training samples to keep. The default
            exceeds the size of the CIFAR-10 training split, which simply
            keeps every sample because slicing clamps to the available
            length.
        tst_size: Number of leading test samples to keep (same clamping).

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.CIFAR10``
        objects.
    """
    data_transforms = {
        'train': T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'val': T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }

    trainset = dset.CIFAR10(root='./datasets', train=True,
                            download=True, transform=data_transforms['train'])
    testset = dset.CIFAR10(root='./datasets', train=False,
                           download=True, transform=data_transforms['val'])

    for dataset, split, size in ((trainset, 'train', trn_size),
                                 (testset, 'test', tst_size)):
        # torchvision >= 0.2.2 exposes samples as ``data`` / ``targets`` and
        # no longer defines ``train_data`` / ``test_labels`` on CIFAR10, so
        # the old attribute names raise AttributeError there. Fall back to
        # them only on legacy releases that lack ``targets``.
        if hasattr(dataset, 'targets'):
            data_attr, label_attr = 'data', 'targets'
        else:
            data_attr, label_attr = split + '_data', split + '_labels'
        setattr(dataset, data_attr, getattr(dataset, data_attr)[:size])
        setattr(dataset, label_attr, getattr(dataset, label_attr)[:size])

    return trainset, testset
def get_data_loader(trainset, testset, batch_size=32):
    """Wrap a train/test dataset pair in ``DataLoader`` objects.

    Args:
        trainset: Dataset served by the training loader.
        testset: Dataset served by the evaluation loader.
        batch_size: Batch size used by both loaders.

    Returns:
        A dict with keys ``'train'`` and ``'test'``. The training loader
        shuffles its samples; the test loader keeps the dataset order.
    """
    return {
        'train': DataLoader(trainset, shuffle=True, batch_size=batch_size),
        'test': DataLoader(testset, shuffle=False, batch_size=batch_size),
    }