Commit 98d77cf

load_data.py: fix typo
1 parent 04622c6 commit 98d77cf

1 file changed: +95 -95 lines changed

load_data.py

@@ -1,96 +1,96 @@
-from mnist import MNIST
-from utils import *
-from six.moves.urllib.request import urlopen
-import gzip, tarfile
-from shutil import move
-
-try:
-    from StringIO import StringIO
-except ImportError:
-    from io import StringIO
-
-SOURCE_URL_MNIST = 'http://yann.lecun.com/exdb/mnist/'
-SOURCE_URL_CIFAR10 = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
-SOURCE_URL_OXFLOWER17 = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz'
-
-MNIST_FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
-
-CIFAR10_TRAIN_DATASETS = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', ]
-CIFAR10_TEST_DATASETS = ['test_batch']
-CIFAR_10_GZ_FILE = 'cifar-10-python.tar.gz'
-CIFAR_10_FOLDER = 'cifar-10-batches-py/'
-
-def unzip_download(download_response):
-    compressedFile = StringIO()
-    compressedFile.write(download_response.read())
-    compressedFile.seek(0)
-    decompressedFile = gzip.GzipFile(fileobj=compressedFile, mode='rb')
-    return decompressedFile
-
-def mnist(input_folder, image_width, image_height, image_depth):
-    if not os.path.exists(input_folder):
-        os.mkdir(input_folder)
-
-    for filename in MNIST_FILES:
-        unzipped_filename = filename.split('.')[0]
-        if unzipped_filename not in os.listdir(input_folder):
-            print('Downloading MNIST file ', filename)
-            response = urlopen(SOURCE_URL_MNIST + filename)
-            with open(input_folder + unzipped_filename, 'wb') as outfile:
-                outfile.write(gzip.decompress(response.read()))
-            print('Succesfully downloaded and unzipped', filename)
-    print("Loading MNIST dataset...")
-    mndata = MNIST(input_folder)
-    train_dataset_, train_labels_ = mndata.load_training()
-    test_dataset_, test_labels_ = mndata.load_testing()
-    train_dataset, train_labels = reformat_data(train_dataset_, train_labels_, image_width, image_height, image_depth)
-    test_dataset, test_labels = reformat_data(test_dataset_, test_labels_, image_width, image_height, image_depth)
-    print("The MNIST training dataset contains {} images, each of size {}".format(train_dataset[:,:,:,:].shape[0], train_dataset[:,:,:,:].shape[1:]))
-    print("The MNIST test dataset contains {} images, each of size {}".format(test_dataset[:,:,:,:].shape[0], test_dataset[:,:,:,:].shape[1:]))
-    print("There are {} number of labels.".format(len(np.unique(train_labels_))))
+from mnist import MNIST
+from utils import *
+from six.moves.urllib.request import urlopen
+import gzip, tarfile
+from shutil import move
+
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
+
+SOURCE_URL_MNIST = 'http://yann.lecun.com/exdb/mnist/'
+SOURCE_URL_CIFAR10 = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+SOURCE_URL_OXFLOWER17 = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz'
+
+MNIST_FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
+
+CIFAR10_TRAIN_DATASETS = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', ]
+CIFAR10_TEST_DATASETS = ['test_batch']
+CIFAR_10_GZ_FILE = 'cifar-10-python.tar.gz'
+CIFAR_10_FOLDER = 'cifar-10-batches-py/'
+
+def unzip_download(download_response):
+    compressedFile = StringIO()
+    compressedFile.write(download_response.read())
+    compressedFile.seek(0)
+    decompressedFile = gzip.GzipFile(fileobj=compressedFile, mode='rb')
+    return decompressedFile
+
+def mnist(input_folder, image_width, image_height, image_depth):
+    if not os.path.exists(input_folder):
+        os.makedirs(input_folder)
+
+    for filename in MNIST_FILES:
+        unzipped_filename = filename.split('.')[0]
+        if unzipped_filename not in os.listdir(input_folder):
+            print('Downloading MNIST file ', filename)
+            response = urlopen(SOURCE_URL_MNIST + filename)
+            with open(input_folder + unzipped_filename, 'wb') as outfile:
+                outfile.write(gzip.decompress(response.read()))
+            print('Succesfully downloaded and unzipped', filename)
+    print("Loading MNIST dataset...")
+    mndata = MNIST(input_folder)
+    train_dataset_, train_labels_ = mndata.load_training()
+    test_dataset_, test_labels_ = mndata.load_testing()
+    train_dataset, train_labels = reformat_data(train_dataset_, train_labels_, image_width, image_height, image_depth)
+    test_dataset, test_labels = reformat_data(test_dataset_, test_labels_, image_width, image_height, image_depth)
+    print("The MNIST training dataset contains {} images, each of size {}".format(train_dataset[:,:,:,:].shape[0], train_dataset[:,:,:,:].shape[1:]))
+    print("The MNIST test dataset contains {} images, each of size {}".format(test_dataset[:,:,:,:].shape[0], test_dataset[:,:,:,:].shape[1:]))
+    print("There are {} number of labels.".format(len(np.unique(train_labels_))))
+    return train_dataset, train_labels, test_dataset, test_labels
+
+def cifar10(input_folder, image_width, image_height, image_depth):
+    if not os.path.exists(input_folder):
+        os.mkdir(input_folder)
+
+    download_flag = False
+    for file in [CIFAR_10_GZ_FILE] + CIFAR10_TRAIN_DATASETS + CIFAR10_TEST_DATASETS:
+        if file not in os.listdir(input_folder):
+            download_flag = True
+
+    if download_flag:
+        print("Downloading CIFAR10 dataset")
+        response = urlopen(SOURCE_URL_CIFAR10)
+        with open(input_folder + CIFAR_10_GZ_FILE, 'wb') as outfile:
+            outfile.write(response.read())
+        print('Succesfully downloaded and unzipped', CIFAR_10_GZ_FILE)
+        print("extracting files...")
+        tar = tarfile.open(input_folder + CIFAR_10_GZ_FILE)
+        tar.extractall(input_folder)
+        files = os.listdir(input_folder + CIFAR_10_FOLDER)
+        for file in files:
+            move(input_folder + CIFAR_10_FOLDER + file, input_folder + file)
+        os.rmdir(input_folder + CIFAR_10_FOLDER)
+    print("Loading Cifar-10 dataset")
+    with open(input_folder + CIFAR10_TEST_DATASETS[0], 'rb') as f0:
+        c10_test_dict = pickle.load(f0, encoding='bytes')
+
+    c10_test_dataset, c10_test_labels = c10_test_dict[b'data'], c10_test_dict[b'labels']
+
+    c10_train_dataset, c10_train_labels = [], []
+    for train_dataset in CIFAR10_TRAIN_DATASETS:
+        with open(input_folder + train_dataset, 'rb') as f0:
+            c10_train_dict = pickle.load(f0, encoding='bytes')
+            c10_train_dataset_, c10_train_labels_ = c10_train_dict[b'data'], c10_train_dict[b'labels']
+
+        c10_train_dataset.append(c10_train_dataset_)
+        c10_train_labels += c10_train_labels_
+
+    c10_train_dataset = np.concatenate(c10_train_dataset, axis=0)
+    test_dataset, test_labels = reformat_data(c10_test_dataset, c10_test_labels, image_width, image_height, image_depth)
+    train_dataset, train_labels = reformat_data(c10_train_dataset, c10_train_labels, image_width, image_height, image_depth)
+    print("The CIFAR-10 training dataset contains {} images, each of size {}".format(train_dataset[:,:,:,:].shape[0], train_dataset[:,:,:,:].shape[1:]))
+    print("The CIFAR-10 test dataset contains {} images, each of size {}".format(test_dataset[:,:,:,:].shape[0], test_dataset[:,:,:,:].shape[1:]))
+    print("There are {} number of labels.".format(len(np.unique(c10_train_labels))))
     return train_dataset, train_labels, test_dataset, test_labels
-
-def cifar10(input_folder, image_width, image_height, image_depth):
-    if not os.path.exists(input_folder):
-        os.mkdir(input_folder)
-
-    download_flag = False
-    for file in [CIFAR_10_GZ_FILE] + CIFAR10_TRAIN_DATASETS + CIFAR10_TEST_DATASETS:
-        if file not in os.listdir(input_folder):
-            download_flag = True
-
-    if download_flag:
-        print("Downloading CIFAR10 dataset")
-        response = urlopen(SOURCE_URL_CIFAR10)
-        with open(input_folder + CIFAR_10_GZ_FILE, 'wb') as outfile:
-            outfile.write(response.read())
-        print('Succesfully downloaded and unzipped', CIFAR_10_GZ_FILE)
-        print("extracting files...")
-        tar = tarfile.open(input_folder + CIFAR_10_GZ_FILE)
-        tar.extractall(input_folder)
-        files = os.listdir(input_folder + CIFAR_10_FOLDER)
-        for file in files:
-            move(input_folder + CIFAR_10_FOLDER + file, input_folder + file)
-        os.rmdir(input_folder + CIFAR_10_FOLDER)
-    print("Loading Cifar-10 dataset")
-    with open(input_folder + CIFAR10_TEST_DATASETS[0], 'rb') as f0:
-        c10_test_dict = pickle.load(f0, encoding='bytes')
-
-    c10_test_dataset, c10_test_labels = c10_test_dict[b'data'], c10_test_dict[b'labels']
-
-    c10_train_dataset, c10_train_labels = [], []
-    for train_dataset in CIFAR10_TRAIN_DATASETS:
-        with open(input_folder + train_dataset, 'rb') as f0:
-            c10_train_dict = pickle.load(f0, encoding='bytes')
-            c10_train_dataset_, c10_train_labels_ = c10_train_dict[b'data'], c10_train_dict[b'labels']
-
-        c10_train_dataset.append(c10_train_dataset_)
-        c10_train_labels += c10_train_labels_
-
-    c10_train_dataset = np.concatenate(c10_train_dataset, axis=0)
-    test_dataset, test_labels = reformat_data(c10_test_dataset, c10_test_labels, image_width, image_height, image_depth)
-    train_dataset, train_labels = reformat_data(c10_train_dataset, c10_train_labels, image_width, image_height, image_depth)
-    print("The CIFAR-10 training dataset contains {} images, each of size {}".format(train_dataset[:,:,:,:].shape[0], train_dataset[:,:,:,:].shape[1:]))
-    print("The CIFAR-10 test dataset contains {} images, each of size {}".format(test_dataset[:,:,:,:].shape[0], test_dataset[:,:,:,:].shape[1:]))
-    print("There are {} number of labels.".format(len(np.unique(c10_train_labels))))
-    return train_dataset, train_labels, test_dataset, test_labels
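Note: the only functional change in this diff appears to be os.mkdir → os.makedirs at line 32 of mnist(); cifar10() still calls os.mkdir. A minimal sketch of the difference, assuming a hypothetical nested input_folder such as 'data/mnist/' (not taken from the commit itself):

import os

input_folder = 'data/mnist/'   # hypothetical nested path, for illustration only

# os.mkdir raises FileNotFoundError when the parent directory ('data/') is missing;
# os.makedirs also creates every missing intermediate directory in the path.
os.makedirs(input_folder, exist_ok=True)

With exist_ok=True the call is additionally safe to repeat, although the committed code guards the call with os.path.exists instead.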
