1
- from mnist import MNIST
2
- from utils import *
3
- from six .moves .urllib .request import urlopen
4
- import gzip , tarfile
5
- from shutil import move
6
-
7
- try :
8
- from StringIO import StringIO
9
- except ImportError :
10
- from io import StringIO
11
-
12
- SOURCE_URL_MNIST = 'http://yann.lecun.com/exdb/mnist/'
13
- SOURCE_URL_CIFAR10 = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
14
- SOURCE_URL_OXFLOWER17 = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz'
15
-
16
- MNIST_FILES = ['train-images-idx3-ubyte.gz' , 'train-labels-idx1-ubyte.gz' , 't10k-images-idx3-ubyte.gz' , 't10k-labels-idx1-ubyte.gz' ]
17
-
18
- CIFAR10_TRAIN_DATASETS = ['data_batch_1' , 'data_batch_2' , 'data_batch_3' , 'data_batch_4' , 'data_batch_5' , ]
19
- CIFAR10_TEST_DATASETS = ['test_batch' ]
20
- CIFAR_10_GZ_FILE = 'cifar-10-python.tar.gz'
21
- CIFAR_10_FOLDER = 'cifar-10-batches-py/'
22
-
23
- def unzip_download (download_response ):
24
- compressedFile = StringIO ()
25
- compressedFile .write (download_response .read ())
26
- compressedFile .seek (0 )
27
- decompressedFile = gzip .GzipFile (fileobj = compressedFile , mode = 'rb' )
28
- return decompressedFile
29
-
30
- def mnist (input_folder , image_width , image_height , image_depth ):
31
- if not os .path .exists (input_folder ):
32
- os .mkdir (input_folder )
33
-
34
- for filename in MNIST_FILES :
35
- unzipped_filename = filename .split ('.' )[0 ]
36
- if unzipped_filename not in os .listdir (input_folder ):
37
- print ('Downloading MNIST file ' , filename )
38
- response = urlopen (SOURCE_URL_MNIST + filename )
39
- with open (input_folder + unzipped_filename , 'wb' ) as outfile :
40
- outfile .write (gzip .decompress (response .read ()))
41
- print ('Succesfully downloaded and unzipped' , filename )
42
- print ("Loading MNIST dataset..." )
43
- mndata = MNIST (input_folder )
44
- train_dataset_ , train_labels_ = mndata .load_training ()
45
- test_dataset_ , test_labels_ = mndata .load_testing ()
46
- train_dataset , train_labels = reformat_data (train_dataset_ , train_labels_ , image_width , image_height , image_depth )
47
- test_dataset , test_labels = reformat_data (test_dataset_ , test_labels_ , image_width , image_height , image_depth )
48
- print ("The MNIST training dataset contains {} images, each of size {}" .format (train_dataset [:,:,:,:].shape [0 ], train_dataset [:,:,:,:].shape [1 :]))
49
- print ("The MNIST test dataset contains {} images, each of size {}" .format (test_dataset [:,:,:,:].shape [0 ], test_dataset [:,:,:,:].shape [1 :]))
50
- print ("There are {} number of labels." .format (len (np .unique (train_labels_ ))))
1
+ from mnist import MNIST
2
+ from utils import *
3
+ from six .moves .urllib .request import urlopen
4
+ import gzip , tarfile
5
+ from shutil import move
6
+
7
+ try :
8
+ from StringIO import StringIO
9
+ except ImportError :
10
+ from io import StringIO
11
+
12
+ SOURCE_URL_MNIST = 'http://yann.lecun.com/exdb/mnist/'
13
+ SOURCE_URL_CIFAR10 = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
14
+ SOURCE_URL_OXFLOWER17 = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz'
15
+
16
+ MNIST_FILES = ['train-images-idx3-ubyte.gz' , 'train-labels-idx1-ubyte.gz' , 't10k-images-idx3-ubyte.gz' , 't10k-labels-idx1-ubyte.gz' ]
17
+
18
+ CIFAR10_TRAIN_DATASETS = ['data_batch_1' , 'data_batch_2' , 'data_batch_3' , 'data_batch_4' , 'data_batch_5' , ]
19
+ CIFAR10_TEST_DATASETS = ['test_batch' ]
20
+ CIFAR_10_GZ_FILE = 'cifar-10-python.tar.gz'
21
+ CIFAR_10_FOLDER = 'cifar-10-batches-py/'
22
+
23
+ def unzip_download (download_response ):
24
+ compressedFile = StringIO ()
25
+ compressedFile .write (download_response .read ())
26
+ compressedFile .seek (0 )
27
+ decompressedFile = gzip .GzipFile (fileobj = compressedFile , mode = 'rb' )
28
+ return decompressedFile
29
+
30
+ def mnist (input_folder , image_width , image_height , image_depth ):
31
+ if not os .path .exists (input_folder ):
32
+ os .makedirs (input_folder )
33
+
34
+ for filename in MNIST_FILES :
35
+ unzipped_filename = filename .split ('.' )[0 ]
36
+ if unzipped_filename not in os .listdir (input_folder ):
37
+ print ('Downloading MNIST file ' , filename )
38
+ response = urlopen (SOURCE_URL_MNIST + filename )
39
+ with open (input_folder + unzipped_filename , 'wb' ) as outfile :
40
+ outfile .write (gzip .decompress (response .read ()))
41
+ print ('Succesfully downloaded and unzipped' , filename )
42
+ print ("Loading MNIST dataset..." )
43
+ mndata = MNIST (input_folder )
44
+ train_dataset_ , train_labels_ = mndata .load_training ()
45
+ test_dataset_ , test_labels_ = mndata .load_testing ()
46
+ train_dataset , train_labels = reformat_data (train_dataset_ , train_labels_ , image_width , image_height , image_depth )
47
+ test_dataset , test_labels = reformat_data (test_dataset_ , test_labels_ , image_width , image_height , image_depth )
48
+ print ("The MNIST training dataset contains {} images, each of size {}" .format (train_dataset [:,:,:,:].shape [0 ], train_dataset [:,:,:,:].shape [1 :]))
49
+ print ("The MNIST test dataset contains {} images, each of size {}" .format (test_dataset [:,:,:,:].shape [0 ], test_dataset [:,:,:,:].shape [1 :]))
50
+ print ("There are {} number of labels." .format (len (np .unique (train_labels_ ))))
51
+ return train_dataset , train_labels , test_dataset , test_labels
52
+
53
+ def cifar10 (input_folder , image_width , image_height , image_depth ):
54
+ if not os .path .exists (input_folder ):
55
+ os .mkdir (input_folder )
56
+
57
+ download_flag = False
58
+ for file in [CIFAR_10_GZ_FILE ] + CIFAR10_TRAIN_DATASETS + CIFAR10_TEST_DATASETS :
59
+ if file not in os .listdir (input_folder ):
60
+ download_flag = True
61
+
62
+ if download_flag :
63
+ print ("Downloading CIFAR10 dataset" )
64
+ response = urlopen (SOURCE_URL_CIFAR10 )
65
+ with open (input_folder + CIFAR_10_GZ_FILE , 'wb' ) as outfile :
66
+ outfile .write (response .read ())
67
+ print ('Succesfully downloaded and unzipped' , CIFAR_10_GZ_FILE )
68
+ print ("extracting files..." )
69
+ tar = tarfile .open (input_folder + CIFAR_10_GZ_FILE )
70
+ tar .extractall (input_folder )
71
+ files = os .listdir (input_folder + CIFAR_10_FOLDER )
72
+ for file in files :
73
+ move (input_folder + CIFAR_10_FOLDER + file , input_folder + file )
74
+ os .rmdir (input_folder + CIFAR_10_FOLDER )
75
+ print ("Loading Cifar-10 dataset" )
76
+ with open (input_folder + CIFAR10_TEST_DATASETS [0 ], 'rb' ) as f0 :
77
+ c10_test_dict = pickle .load (f0 , encoding = 'bytes' )
78
+
79
+ c10_test_dataset , c10_test_labels = c10_test_dict [b'data' ], c10_test_dict [b'labels' ]
80
+
81
+ c10_train_dataset , c10_train_labels = [], []
82
+ for train_dataset in CIFAR10_TRAIN_DATASETS :
83
+ with open (input_folder + train_dataset , 'rb' ) as f0 :
84
+ c10_train_dict = pickle .load (f0 , encoding = 'bytes' )
85
+ c10_train_dataset_ , c10_train_labels_ = c10_train_dict [b'data' ], c10_train_dict [b'labels' ]
86
+
87
+ c10_train_dataset .append (c10_train_dataset_ )
88
+ c10_train_labels += c10_train_labels_
89
+
90
+ c10_train_dataset = np .concatenate (c10_train_dataset , axis = 0 )
91
+ test_dataset , test_labels = reformat_data (c10_test_dataset , c10_test_labels , image_width , image_height , image_depth )
92
+ train_dataset , train_labels = reformat_data (c10_train_dataset , c10_train_labels , image_width , image_height , image_depth )
93
+ print ("The CIFAR-10 training dataset contains {} images, each of size {}" .format (train_dataset [:,:,:,:].shape [0 ], train_dataset [:,:,:,:].shape [1 :]))
94
+ print ("The CIFAR-10 test dataset contains {} images, each of size {}" .format (test_dataset [:,:,:,:].shape [0 ], test_dataset [:,:,:,:].shape [1 :]))
95
+ print ("There are {} number of labels." .format (len (np .unique (c10_train_labels ))))
51
96
return train_dataset , train_labels , test_dataset , test_labels
52
-
53
- def cifar10 (input_folder , image_width , image_height , image_depth ):
54
- if not os .path .exists (input_folder ):
55
- os .mkdir (input_folder )
56
-
57
- download_flag = False
58
- for file in [CIFAR_10_GZ_FILE ] + CIFAR10_TRAIN_DATASETS + CIFAR10_TEST_DATASETS :
59
- if file not in os .listdir (input_folder ):
60
- download_flag = True
61
-
62
- if download_flag :
63
- print ("Downloading CIFAR10 dataset" )
64
- response = urlopen (SOURCE_URL_CIFAR10 )
65
- with open (input_folder + CIFAR_10_GZ_FILE , 'wb' ) as outfile :
66
- outfile .write (response .read ())
67
- print ('Succesfully downloaded and unzipped' , CIFAR_10_GZ_FILE )
68
- print ("extracting files..." )
69
- tar = tarfile .open (input_folder + CIFAR_10_GZ_FILE )
70
- tar .extractall (input_folder )
71
- files = os .listdir (input_folder + CIFAR_10_FOLDER )
72
- for file in files :
73
- move (input_folder + CIFAR_10_FOLDER + file , input_folder + file )
74
- os .rmdir (input_folder + CIFAR_10_FOLDER )
75
- print ("Loading Cifar-10 dataset" )
76
- with open (input_folder + CIFAR10_TEST_DATASETS [0 ], 'rb' ) as f0 :
77
- c10_test_dict = pickle .load (f0 , encoding = 'bytes' )
78
-
79
- c10_test_dataset , c10_test_labels = c10_test_dict [b'data' ], c10_test_dict [b'labels' ]
80
-
81
- c10_train_dataset , c10_train_labels = [], []
82
- for train_dataset in CIFAR10_TRAIN_DATASETS :
83
- with open (input_folder + train_dataset , 'rb' ) as f0 :
84
- c10_train_dict = pickle .load (f0 , encoding = 'bytes' )
85
- c10_train_dataset_ , c10_train_labels_ = c10_train_dict [b'data' ], c10_train_dict [b'labels' ]
86
-
87
- c10_train_dataset .append (c10_train_dataset_ )
88
- c10_train_labels += c10_train_labels_
89
-
90
- c10_train_dataset = np .concatenate (c10_train_dataset , axis = 0 )
91
- test_dataset , test_labels = reformat_data (c10_test_dataset , c10_test_labels , image_width , image_height , image_depth )
92
- train_dataset , train_labels = reformat_data (c10_train_dataset , c10_train_labels , image_width , image_height , image_depth )
93
- print ("The CIFAR-10 training dataset contains {} images, each of size {}" .format (train_dataset [:,:,:,:].shape [0 ], train_dataset [:,:,:,:].shape [1 :]))
94
- print ("The CIFAR-10 test dataset contains {} images, each of size {}" .format (test_dataset [:,:,:,:].shape [0 ], test_dataset [:,:,:,:].shape [1 :]))
95
- print ("There are {} number of labels." .format (len (np .unique (c10_train_labels ))))
96
- return train_dataset , train_labels , test_dataset , test_labels
0 commit comments