-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbrowser_loader.js
117 lines (92 loc) · 4.03 KB
/
browser_loader.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
* @license
*/
import * as tf from '@tensorflow/tfjs'
// Pixels per MNIST digit: each 28x28 grayscale image is flattened into one row.
const IMAGE_SIZE = 28 * 28
// Digits 0-9. Each example's label occupies NUM_CLASSES entries in the label
// file (presumably one-hot encoded -- confirm against the data source).
const NUM_CLASSES = 10
// Total number of examples contained in the sprite image and the label file.
const NUM_DATASET_ELEMENTS = 65000

// Fraction of the dataset (taken from the front) used for training; the
// remainder is held out for testing.
const TRAIN_TEST_RATIO = 5 / 6
const NUM_TRAIN_ELEMENTS = Math.floor(NUM_DATASET_ELEMENTS * TRAIN_TEST_RATIO)
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS

// Remote locations of the sprited image data and the raw uint8 label bytes.
const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'
const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'
/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 *
 * `load()` must complete before `nextTrainBatch` / `nextTestBatch` are called.
 */
export class MnistData {
  constructor() {
    // Cursors into trainIndices / testIndices, advanced by the batch methods.
    this.shuffledTrainIndex = 0
    this.shuffledTestIndex = 0
  }

  /**
   * Downloads the MNIST sprite image and the label file in parallel, decodes
   * the sprite into pixel values scaled to [0, 1], and splits images and
   * labels into train and test sets.
   *
   * Populates datasetImages, datasetLabels, trainImages, testImages,
   * trainLabels, testLabels, trainIndices and testIndices.
   *
   * @returns {Promise<void>} Resolves once all of the fields above are set.
   * @throws {Error} If the sprite image fails to load or decode, or the labels
   *   request responds with a non-2xx status.
   */
  async load() {
    const img = new Image()
    const canvas = document.createElement('canvas')
    const ctx = canvas.getContext('2d')
    const imgRequest = new Promise((resolve, reject) => {
      // Anonymous CORS request so the canvas is not tainted and
      // getImageData() is permitted.
      img.crossOrigin = ''
      img.onload = () => {
        // An exception thrown inside this event handler does not reject the
        // promise on its own (e.g. getImageData on a tainted canvas, or a null
        // 2D context), which would leave load() pending forever -- forward it.
        try {
          img.width = img.naturalWidth
          img.height = img.naturalHeight
          const datasetImages = new Float32Array(NUM_DATASET_ELEMENTS * IMAGE_SIZE)
          // Decode the sprite in horizontal bands of chunkSize rows (presumably
          // to keep the canvas small); each sprite row holds one example.
          const chunkSize = 5000
          const valuesPerChunk = IMAGE_SIZE * chunkSize
          canvas.width = img.width
          canvas.height = chunkSize
          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const chunkView = datasetImages.subarray(i * valuesPerChunk, (i + 1) * valuesPerChunk)
            ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize)
            const pixels = ctx.getImageData(0, 0, canvas.width, canvas.height).data
            for (let j = 0; j < pixels.length / 4; j++) {
              // All channels hold an equal value since the image is grayscale,
              // so just read the red channel of each RGBA pixel.
              chunkView[j] = pixels[j * 4] / 255
            }
          }
          this.datasetImages = datasetImages
          resolve()
        } catch (err) {
          reject(err)
        }
      }
      img.onerror = () => {
        reject(new Error(`Failed to load MNIST sprite image: ${MNIST_IMAGES_SPRITE_PATH}`))
      }
      img.src = MNIST_IMAGES_SPRITE_PATH
    })
    const labelsRequest = fetch(MNIST_LABELS_PATH).then((response) => {
      // fetch() only rejects on network failure; an HTTP error body must not
      // be mistaken for label data.
      if (!response.ok) {
        throw new Error(
          `Failed to fetch MNIST labels from ${MNIST_LABELS_PATH}: ${response.status} ${response.statusText}`
        )
      }
      return response.arrayBuffer()
    })
    const [, labelsBuffer] = await Promise.all([imgRequest, labelsRequest])
    this.datasetLabels = new Uint8Array(labelsBuffer)

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS)
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS)

    // Slice the images and labels into train and test sets: the first
    // NUM_TRAIN_ELEMENTS examples train, the rest test.
    this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS)
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS)
    this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS)
    this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS)
  }

  /**
   * Returns the next `batchSize` training examples in shuffled order.
   *
   * The cursor is advanced before each read and wraps modulo the number of
   * training indices, so the shuffled order repeats across passes.
   *
   * @param {number} batchSize - Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} xs has shape
   *   [batchSize, IMAGE_SIZE]; labels has shape [batchSize, NUM_CLASSES].
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(batchSize, [this.trainImages, this.trainLabels], () => {
      this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length
      return this.trainIndices[this.shuffledTrainIndex]
    })
  }

  /**
   * Returns the next `batchSize` test examples in shuffled order.
   *
   * Same cursor semantics as nextTrainBatch, using the test indices.
   *
   * @param {number} batchSize - Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} xs has shape
   *   [batchSize, IMAGE_SIZE]; labels has shape [batchSize, NUM_CLASSES].
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length
      return this.testIndices[this.shuffledTestIndex]
    })
  }

  /**
   * Builds a batch by copying `batchSize` examples, chosen by `index`, out of
   * flat image and label arrays.
   *
   * @param {number} batchSize - Number of examples in the batch.
   * @param {Array} data - `[images, labels]`: a flat Float32Array with
   *   IMAGE_SIZE values per example and a flat Uint8Array with NUM_CLASSES
   *   values per example.
   * @param {() => number} index - Called once per example; returns the index
   *   of the example to copy next.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} xs has shape
   *   [batchSize, IMAGE_SIZE]; labels has shape [batchSize, NUM_CLASSES].
   */
  nextBatch(batchSize, data, index) {
    const images = data[0]
    const labels = data[1]
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE)
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES)
    for (let i = 0; i < batchSize; i++) {
      const idx = index()
      // subarray() is a zero-copy view; set() performs the only copy needed.
      batchImagesArray.set(images.subarray(idx * IMAGE_SIZE, (idx + 1) * IMAGE_SIZE), i * IMAGE_SIZE)
      batchLabelsArray.set(labels.subarray(idx * NUM_CLASSES, (idx + 1) * NUM_CLASSES), i * NUM_CLASSES)
    }
    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE])
    const batchLabels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES])
    return { xs, labels: batchLabels }
  }
}