Commit 0c44284
Minor improvement.
1 parent cc160be commit 0c44284

File tree: 8 files changed, +75 -104 lines changed


chapter_deep-learning-computation/custom-layer.ipynb (-2)

@@ -299,7 +299,6 @@
 "source": [
 "NDArray input = manager.randomUniform(0, 1, new Shape(2, 5));\n",
 "\n",
-"linear.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);\n",
 "linear.initialize(manager, DataType.FLOAT32, input.getShape());\n",
 "\n",
 "Model model = Model.newInstance(\"my-linear\");\n",
@@ -328,7 +327,6 @@
 "SequentialBlock net = new SequentialBlock();\n",
 "net.add(new MyLinear(8, 64)); // 64 units in -> 8 units out\n",
 "net.add(new MyLinear(1, 8)); // 8 units in -> 1 unit out\n",
-"net.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);\n",
 "net.initialize(manager, DataType.FLOAT32, input.getShape());\n",
 "\n",
 "Model model = Model.newInstance(\"lin-reg-custom\");\n",

chapter_deep-learning-computation/parameters.ipynb (+17 -42)

@@ -332,18 +332,15 @@
 "\n",
 "This setup has the advantage that we don't have to worry about our `setInitializer()` overriding our previous `initializer`s on internal blocks!\n",
 "\n",
-"If you want to however, you can explicitly set an initializer for a `Parameter` by calling its `setInitializer()` function directly and passing in `true` to the overwrite input.\n",
-"Simply loop over all the parameters returned from `getParameters()` and set their initializers directly!"
+"If you want to however, you can explicitly set an initializer for a `Parameter` by calling its `setInitializer()` function directly."
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Let us begin by calling on built-in initializers. \n",
-"The code below initializes all parameters \n",
-"to a given constant value 1, \n",
-"by using the `ConstantInitializer()` initializer. \n",
+"Let us begin by calling on built-in initializers. The code below initializes all parameters \n",
+"to a given constant value 1, by using the `ConstantInitializer()` initializer. \n",
 "\n",
 "Note that this will not do anything currently since we have already set\n",
 "our initializer in the previous code block.\n",
@@ -430,7 +427,7 @@
 },
 "outputs": [],
 "source": [
-"SequentialBlock net = getNet();\n",
+"net = getNet();\n",
 "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
 "net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
 "Block linearLayer = net.getChildren().valueAt(0);\n",
@@ -444,7 +441,7 @@
 "source": [
 "We can also apply different initializers for certain Blocks.\n",
 "For example, below we initialize the first layer\n",
-"with the `Xavier` initializer\n",
+"with the `XavierInitializer` initializer\n",
 "and initialize the second layer \n",
 "to a constant value of 0.\n",
 "\n",
@@ -464,7 +461,7 @@
 },
 "outputs": [],
 "source": [
-"SequentialBlock net = new SequentialBlock();\n",
+"net = new SequentialBlock();\n",
 "Linear linear1 = Linear.builder().setUnits(8).build();\n",
 "net.add(linear1);\n",
 "net.add(Activation.reluBlock());\n",
@@ -485,15 +482,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Finally, we can loop over the `ParameterList` and set their initializers individually.\n",
-"When setting initializers directly on the `Parameter`, you must pass in an `overwrite`\n",
-"boolean along with the initializer to declare whether you want your current\n",
-"initializer to overwrite the previous initializer if one has already been set.\n",
-"Here, we do want to overwrite and so pass in `true`. \n",
-"\n",
-"For this example, however, since we haven't set the `weight` initializers before, there is no initializer to overwrite so we could pass in `false` and still have the same outcome.\n",
-"\n",
-"However, since `bias` parameters are automatically set to initialize at 0, to properly set our intializer here, we have to set overwrite to `true`."
+"Finally, we can directly access the `Parameter.setInitializer()` and set their initializers individually."
 ]
 },
 {
@@ -502,30 +491,16 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"SequentialBlock net = getNet();\n",
+"net = getNet();\n",
 "ParameterList params = net.getParameters();\n",
-"for (int i = 0; i < params.size(); i++) {\n",
-"    // Here we interleave initializers.\n",
-"    // We initialize parameters at even indexes to 0\n",
-"    // and parameters at odd indexes to 2.\n",
-"    Parameter param = params.valueAt(i);\n",
-"    if (i % 2 == 0) {\n",
-"        // All weight parameters happen to be at even indices.\n",
-"        // We set them to initialize to 0.\n",
-"        param.setInitializer(new ConstantInitializer(0));\n",
-"    }\n",
-"    else {\n",
-"        // All bias parameters happen to be at odd indices.\n",
-"        // We set them to initialize to 2.\n",
-"        param.setInitializer(new ConstantInitializer(2));\n",
-"    }\n",
-"}\n",
-"net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
 "\n",
-"for (var param : net.getParameters()) {\n",
-"    System.out.println(param.getKey());\n",
-"    System.out.println(param.getValue().getArray());\n",
-"}"
+"params.get(\"01Linear_weight\").setInitializer(new NormalInitializer());\n",
+"params.get(\"03Linear_weight\").setInitializer(Initializer.ONES);\n",
+"\n",
+"net.initialize(manager, DataType.FLOAT32, new Shape(2, 4));\n",
+"\n",
+"System.out.println(params.valueAt(0).getArray());\n",
+"System.out.println(params.valueAt(2).getArray());"
 ]
 },
 {
@@ -563,7 +538,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"class MyInit implements Initializer {\n",
+"static class MyInit implements Initializer {\n",
 "\n",
 "    public MyInit() {}\n",
 "\n",
@@ -593,7 +568,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"SequentialBlock net = getNet();\n",
+"net = getNet();\n",
 "net.setInitializer(new MyInit(), Parameter.Type.WEIGHT);\n",
 "net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
 "Block linearLayer = net.getChildren().valueAt(0);\n",
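The last diff above replaces an even/odd index loop over the `ParameterList` with lookups keyed by parameter name. As a hedged, DJL-free sketch of that design choice (plain JDK `Map`; the keys are hypothetical stand-ins mirroring the `01Linear_weight` / `03Linear_weight` naming seen in the diff), name-keyed assignment does not silently break when the parameter ordering changes, unlike index-parity logic:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class KeyedParams {
    public static void main(String[] args) {
        // Hypothetical parameter table standing in for DJL's ParameterList;
        // values represent each parameter's assigned initializer.
        Map<String, String> params = new LinkedHashMap<>();
        params.put("01Linear_weight", "default");
        params.put("01Linear_bias", "zeros");
        params.put("03Linear_weight", "default");
        params.put("03Linear_bias", "zeros");

        // Name-keyed assignment: robust to parameter ordering.
        params.put("01Linear_weight", "normal");
        params.put("03Linear_weight", "ones");

        System.out.println(params.get("01Linear_weight"));
        System.out.println(params.get("03Linear_weight"));
    }
}
```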

chapter_deep-learning-computation/read-write.ipynb (+25 -2)

@@ -60,7 +60,7 @@
 "try (FileOutputStream fos = new FileOutputStream(\"x-file\")) {\n",
 "    fos.write(x.encode());\n",
 "}\n",
-"x;"
+"x"
 ]
 },
 {
@@ -89,7 +89,30 @@
 "    // from a `FileInputStream` and return it as a `byte[]`.\n",
 "    x2 = NDArray.decode(manager, Utils.toByteArray(fis));\n",
 "}\n",
-"x2;"
+"x2"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"We can also store `NDList` into a file and load it back:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"NDList list = new NDList(x, x2);\n",
+"try (FileOutputStream fos = new FileOutputStream(\"x-file\")) {\n",
+"    fos.write(list.encode());\n",
+"}\n",
+"try (FileInputStream fis = new FileInputStream(\"x-file\")) {\n",
+"    list = NDList.decode(manager, Utils.toByteArray(fis));\n",
+"}\n",
+"list"
 ]
 },
 {
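The new `NDList` cell follows the same save-then-load pattern as the `NDArray` case. A DJL-free sketch of that pattern in plain Java, where a `byte[]` stands in for the `list.encode()` payload and the JDK's `readAllBytes` (Java 9+) plays the role the diff gives to `Utils.toByteArray`:

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;

public class RoundTrip {
    public static void main(String[] args) throws IOException {
        byte[] payload = {1, 2, 3, 4}; // stand-in for list.encode()
        File file = File.createTempFile("x-file", null);
        file.deleteOnExit();

        // Write the encoded bytes, as the notebook does with FileOutputStream.
        try (FileOutputStream fos = new FileOutputStream(file)) {
            fos.write(payload);
        }

        // Read them back and verify the round trip.
        byte[] restored;
        try (FileInputStream fis = new FileInputStream(file)) {
            restored = fis.readAllBytes();
        }
        System.out.println(Arrays.equals(payload, restored)); // true
    }
}
```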

chapter_linear-networks/linear-regression-djl.ipynb (+1 -1)

@@ -284,7 +284,7 @@
 "source": [
 "DefaultTrainingConfig config = new DefaultTrainingConfig(l2loss)\n",
 "    .optOptimizer(sgd) // Optimizer (loss function)\n",
-"    .optDevices(Engine.getInstance().getDevices(1)) // single GPU\n",
+"    .optDevices(manager.getEngine().getDevices(1)) // single GPU\n",
 "    .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
 "\n",
 "Trainer trainer = model.newTrainer(config);"

chapter_linear-networks/softmax-regression-djl.ipynb (+2 -2)

@@ -245,7 +245,7 @@
 "source": [
 "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
 "    .optOptimizer(sgd) // Optimizer\n",
-"    .optDevices(Engine.getInstance().getDevices(1)) // single GPU\n",
+"    .optDevices(manager.getEngine().getDevices(1)) // single GPU\n",
 "    .addEvaluator(new Accuracy()) // Model Accuracy\n",
 "    .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
 "\n",
@@ -307,7 +307,7 @@
 "int numEpochs = 3;\n",
 "\n",
 "EasyTrain.fit(trainer, numEpochs, trainingSet, validationSet);\n",
-"trainer.getTrainingResult().getValidateEvaluation(\"Accuracy\")"
+"var result = trainer.getTrainingResult();"
 ]
 },
 {

chapter_linear-networks/softmax-regression-scratch.ipynb (-1)

@@ -66,7 +66,6 @@
 "    .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
 "    .build();\n",
 "\n",
-"\n",
 "FashionMnist validationSet = FashionMnist.builder()\n",
 "    .optUsage(Dataset.Usage.TEST)\n",
 "    .setSampling(batchSize, false)\n",

chapter_natural-language-processing-pretraining/word-embedding-dataset.ipynb (-6)

@@ -40,12 +40,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import ai.djl.Device;\n",
-"import ai.djl.Model;\n",
-"import ai.djl.engine.Engine;\n",
-"import ai.djl.ndarray.*;\n",
-"import ai.djl.ndarray.index.NDIndex;\n",
-"\n",
 "import java.util.stream.*;\n",
 "import org.apache.commons.math3.distribution.EnumeratedDistribution;"
 ]

chapter_recurrent-modern/machine-translation-and-dataset.ipynb (+30 -48)

@@ -90,6 +90,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"import java.nio.charset.*;\n",
 "import java.util.zip.*;\n",
 "import java.util.stream.*;"
 ]
@@ -132,38 +133,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"public static StringBuilder readDataNMT() throws IOException {\n",
-"    File file = new File(\"./fra-eng.zip\");\n",
-"    if (!file.exists()) {\n",
-"        InputStream inputStream =\n",
-"            new URL(\"http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip\").openStream();\n",
-"        Files.copy(\n",
-"            inputStream, Paths.get(\"./fra-eng.zip\"), StandardCopyOption.REPLACE_EXISTING);\n",
-"    }\n",
-"\n",
-"    ZipFile zipFile = new ZipFile(file);\n",
+"public static String readDataNMT() throws IOException {\n",
+"    DownloadUtils.download(\n",
+"        \"http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip\", \"fra-eng.zip\");\n",
+"    ZipFile zipFile = new ZipFile(new File(\"fra-eng.zip\"));\n",
 "    Enumeration<? extends ZipEntry> entries = zipFile.entries();\n",
-"    InputStream stream = null;\n",
 "    while (entries.hasMoreElements()) {\n",
 "        ZipEntry entry = entries.nextElement();\n",
 "        if (entry.getName().contains(\"fra.txt\")) {\n",
-"            stream = zipFile.getInputStream(entry);\n",
-"            break;\n",
+"            InputStream stream = zipFile.getInputStream(entry);\n",
+"            return new String(stream.readAllBytes(), StandardCharsets.UTF_8);\n",
 "        }\n",
 "    }\n",
-"\n",
-"    String[] lines;\n",
-"    try (BufferedReader in = new BufferedReader(new InputStreamReader(stream))) {\n",
-"        lines = in.lines().toArray(String[]::new);\n",
-"    }\n",
-"    StringBuilder output = new StringBuilder();\n",
-"    for (int i = 0; i < lines.length; i++) {\n",
-"        output.append(lines[i] + \"\\n\");\n",
-"    }\n",
-"    return output;\n",
+"    return null;\n",
 "}\n",
 "\n",
-"StringBuilder rawText = readDataNMT();\n",
+"String rawText = readDataNMT();\n",
 "System.out.println(rawText.substring(0, 75));"
 ]
 },
@@ -188,7 +173,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"public static StringBuilder preprocessNMT(String text) {\n",
+"public static String preprocessNMT(String text) {\n",
 "    // Replace non-breaking space with space, and convert uppercase letters to\n",
 "    // lowercase ones\n",
 "\n",
@@ -204,7 +189,7 @@
 "    }\n",
 "        out.append(currChar);\n",
 "    }\n",
-"    return out;\n",
+"    return out.toString();\n",
 "}\n",
 "\n",
 "public static boolean noSpace(Character currChar, Character prevChar) {\n",
@@ -213,7 +198,7 @@
 "        && prevChar != ' ';\n",
 "}\n",
 "\n",
-"StringBuilder text = preprocessNMT(rawText.toString());\n",
+"String text = preprocessNMT(rawText);\n",
 "System.out.println(text.substring(0, 80));"
 ]
 },
@@ -281,7 +266,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"for (String[] subArr : target.subList(0, 6)) System.out.println(Arrays.toString(subArr));"
+"for (String[] subArr : target.subList(0, 6)) {\n",
+"    System.out.println(Arrays.toString(subArr));\n",
+"}"
 ]
 },
 {
@@ -407,9 +394,11 @@
 "public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) {\n",
 "    /* Truncate or pad sequences */\n",
 "    int[] line = Arrays.stream(integerLine).mapToInt(i -> i).toArray();\n",
-"    if (line.length > numSteps) return Arrays.copyOfRange(line, 0, numSteps);\n",
+"    if (line.length > numSteps) {\n",
+"        return Arrays.copyOfRange(line, 0, numSteps);\n",
+"    }\n",
 "    int[] paddingTokenArr = new int[numSteps - line.length]; // Pad\n",
-"    for (int i = 0; i < paddingTokenArr.length; i++) paddingTokenArr[i] = paddingToken;\n",
+"    Arrays.fill(paddingTokenArr, paddingToken);\n",
 "\n",
 "    return IntStream.concat(Arrays.stream(line), Arrays.stream(paddingTokenArr)).toArray();\n",
 "}\n",
@@ -451,19 +440,20 @@
 "outputs": [],
 "source": [
 "public static Pair<NDArray, NDArray> buildArrayNMT(\n",
-"        ArrayList<String[]> lines, Vocab vocab, int numSteps) {\n",
+"        List<String[]> lines, Vocab vocab, int numSteps) {\n",
 "    /* Transform text sequences of machine translation into minibatches. */\n",
 "    List<Integer[]> linesIntArr = new ArrayList<>();\n",
-"    for (int i = 0; i < lines.size(); i++) {\n",
-"        linesIntArr.add(vocab.getIdxs(lines.get(i)));\n",
+"    for (String[] strings : lines) {\n",
+"        linesIntArr.add(vocab.getIdxs(strings));\n",
 "    }\n",
 "    for (int i = 0; i < linesIntArr.size(); i++) {\n",
-"        ArrayList<Integer> temp = new ArrayList<>();\n",
-"        temp.addAll(Arrays.asList(linesIntArr.get(i)));\n",
+"        List<Integer> temp = new ArrayList<>(Arrays.asList(linesIntArr.get(i)));\n",
 "        temp.add(vocab.getIdx(\"<eos>\"));\n",
-"        linesIntArr.set(i, temp.stream().toArray(n -> new Integer[n]));\n",
+"        linesIntArr.set(i, temp.toArray(new Integer[0]));\n",
 "    }\n",
 "\n",
+"    NDManager manager = NDManager.newBaseManager();\n",
+"\n",
 "    NDArray arr = manager.create(new Shape(linesIntArr.size(), numSteps), DataType.INT32);\n",
 "    int row = 0;\n",
 "    for (Integer[] line : linesIntArr) {\n",
@@ -498,19 +488,18 @@
 "public static Pair<ArrayDataset, Pair<Vocab, Vocab>> loadDataNMT(\n",
 "        int batchSize, int numSteps, int numExamples) throws IOException {\n",
 "    /* Return the iterator and the vocabularies of the translation dataset. */\n",
-"    StringBuilder text = preprocessNMT(readDataNMT().toString());\n",
-"    Pair<ArrayList<String[]>, ArrayList<String[]>> pair =\n",
-"        tokenizeNMT(text.toString(), numExamples);\n",
+"    String text = preprocessNMT(readDataNMT());\n",
+"    Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text, numExamples);\n",
 "    ArrayList<String[]> source = pair.getKey();\n",
 "    ArrayList<String[]> target = pair.getValue();\n",
 "    Vocab srcVocab =\n",
 "        new Vocab(\n",
-"            source.stream().toArray(String[][]::new),\n",
+"            source.toArray(String[][]::new),\n",
 "            2,\n",
 "            new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
 "    Vocab tgtVocab =\n",
 "        new Vocab(\n",
-"            target.stream().toArray(String[][]::new),\n",
+"            target.toArray(String[][]::new),\n",
 "            2,\n",
 "            new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
 "\n",
@@ -582,13 -571,6 @@
 "1. Try different values of the `numExamples` argument in the `loadDataNMT` function. How does this affect the vocabulary sizes of the source language and the target language?\n",
 "1. Text in some languages such as Chinese and Japanese does not have word boundary indicators (e.g., space). Is word-level tokenization still a good idea for such cases? Why or why not?\n"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
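The refactored `truncatePad` above uses only JDK classes (`Arrays.fill`, `IntStream.concat`), so its behavior can be checked standalone. A self-contained sketch with a small driver, mirroring the method as it appears in the diff:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class TruncatePadDemo {
    /* Truncate or pad a sequence to exactly numSteps tokens. */
    public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) {
        int[] line = Arrays.stream(integerLine).mapToInt(i -> i).toArray();
        if (line.length > numSteps) {
            return Arrays.copyOfRange(line, 0, numSteps); // Truncate
        }
        int[] paddingTokenArr = new int[numSteps - line.length]; // Pad
        Arrays.fill(paddingTokenArr, paddingToken);
        return IntStream.concat(Arrays.stream(line), Arrays.stream(paddingTokenArr)).toArray();
    }

    public static void main(String[] args) {
        // Shorter than numSteps: padded with the padding token.
        System.out.println(Arrays.toString(truncatePad(new Integer[] {1, 2, 3}, 5, 0)));
        // [1, 2, 3, 0, 0]

        // Longer than numSteps: truncated.
        System.out.println(Arrays.toString(truncatePad(new Integer[] {1, 2, 3, 4, 5, 6}, 5, 0)));
        // [1, 2, 3, 4, 5]
    }
}
```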
