Skip to content

Commit 09afba5

Browse files
authored
Upgrade to 0.11.0 (#139)
1 parent a9b4deb commit 09afba5

File tree

104 files changed

+2823
-4954
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+2823
-4954
lines changed

.github/workflows/pr_notebook.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ jobs:
5555
./gradlew installKernel
5656
- name: test Notebook
5757
run: |
58+
export DATASET_LIMIT=512
59+
export MAX_EPOCH=2
5860
bash test_notebook.sh $${{ matrix.group }}
5961
- name: generated Notebook in html
6062
uses: actions/upload-artifact@v1
6163
if: always()
6264
with:
6365
name: notebook
64-
path: test_output/
66+
path: test_output/

chapter_attention-mechanisms/attention-cues.ipynb

+48-72
Original file line numberDiff line numberDiff line change
@@ -172,42 +172,11 @@
172172
"metadata": {},
173173
"outputs": [],
174174
"source": [
175-
"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
176-
"\n",
177-
"%maven ai.djl:api:0.10.0\n",
178-
"%maven org.slf4j:slf4j-api:1.7.26\n",
179-
"%maven org.slf4j:slf4j-simple:1.7.26\n",
180-
"\n",
181-
"%maven ai.djl.mxnet:mxnet-engine:0.10.0\n",
182-
"%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
183-
]
184-
},
185-
{
186-
"cell_type": "code",
187-
"execution_count": null,
188-
"metadata": {},
189-
"outputs": [],
190-
"source": [
175+
"%load ../utils/djl-imports\n",
191176
"%load ../utils/plot-utils\n",
192177
"%load ../utils/Functions.java"
193178
]
194179
},
195-
{
196-
"cell_type": "code",
197-
"execution_count": null,
198-
"metadata": {},
199-
"outputs": [],
200-
"source": [
201-
"import ai.djl.ndarray.NDArray;\n",
202-
"import ai.djl.ndarray.NDManager;\n",
203-
"import ai.djl.ndarray.types.Shape;\n",
204-
"import ai.djl.translate.TranslateException;\n",
205-
"import tech.tablesaw.plotly.Plot;\n",
206-
"import tech.tablesaw.plotly.components.*;\n",
207-
"import tech.tablesaw.plotly.traces.HeatmapTrace;\n",
208-
"import tech.tablesaw.plotly.traces.Trace;"
209-
]
210-
},
211180
{
212181
"cell_type": "code",
213182
"execution_count": null,
@@ -241,48 +210,48 @@
241210
" String[] titles,\n",
242211
" int width,\n",
243212
" int height) {\n",
244-
" int numRows = (int) matrices.getShape().get(0);\n",
245-
" int numCols = (int) matrices.getShape().get(1);\n",
213+
" int numRows = (int) matrices.getShape().get(0);\n",
214+
" int numCols = (int) matrices.getShape().get(1);\n",
246215
"\n",
247-
" Trace[] traces = new Trace[numRows * numCols];\n",
248-
" int count = 0;\n",
249-
" for (int i = 0; i < numRows; i++) {\n",
250-
" for (int j = 0; j < numCols; j++) {\n",
251-
" NDArray NDMatrix = matrices.get(i).get(j);\n",
252-
" double[][] matrix =\n",
253-
" new double[(int) NDMatrix.getShape().get(0)]\n",
254-
" [(int) NDMatrix.getShape().get(1)];\n",
255-
" Object[] x = new Object[matrix.length];\n",
256-
" Object[] y = new Object[matrix.length];\n",
257-
" for (int k = 0; k < NDMatrix.getShape().get(0); k++) {\n",
258-
" matrix[k] = Functions.floatToDoubleArray(NDMatrix.get(k).toFloatArray());\n",
259-
" x[k] = k;\n",
260-
" y[k] = k;\n",
261-
" }\n",
262-
" HeatmapTrace.HeatmapBuilder builder = HeatmapTrace.builder(x, y, matrix);\n",
263-
" if (titles != null) {\n",
264-
" builder = (HeatmapTrace.HeatmapBuilder) builder.name(titles[j]);\n",
265-
" }\n",
266-
" traces[count++] = builder.build();\n",
216+
" Trace[] traces = new Trace[numRows * numCols];\n",
217+
" int count = 0;\n",
218+
" for (int i = 0; i < numRows; i++) {\n",
219+
" for (int j = 0; j < numCols; j++) {\n",
220+
" NDArray NDMatrix = matrices.get(i).get(j);\n",
221+
" double[][] matrix =\n",
222+
" new double[(int) NDMatrix.getShape().get(0)]\n",
223+
" [(int) NDMatrix.getShape().get(1)];\n",
224+
" Object[] x = new Object[matrix.length];\n",
225+
" Object[] y = new Object[matrix.length];\n",
226+
" for (int k = 0; k < NDMatrix.getShape().get(0); k++) {\n",
227+
" matrix[k] = Functions.floatToDoubleArray(NDMatrix.get(k).toFloatArray());\n",
228+
" x[k] = k;\n",
229+
" y[k] = k;\n",
230+
" }\n",
231+
" HeatmapTrace.HeatmapBuilder builder = HeatmapTrace.builder(x, y, matrix);\n",
232+
" if (titles != null) {\n",
233+
" builder = (HeatmapTrace.HeatmapBuilder) builder.name(titles[j]);\n",
267234
" }\n",
235+
" traces[count++] = builder.build();\n",
268236
" }\n",
269-
" Grid grid =\n",
270-
" Grid.builder()\n",
271-
" .columns(numCols)\n",
272-
" .rows(numRows)\n",
273-
" .pattern(Grid.Pattern.INDEPENDENT)\n",
274-
" .build();\n",
275-
" Layout layout =\n",
276-
" Layout.builder()\n",
277-
" .title(\"\")\n",
278-
" .xAxis(Axis.builder().title(xLabel).build())\n",
279-
" .yAxis(Axis.builder().title(yLabel).build())\n",
280-
" .width(width)\n",
281-
" .height(height)\n",
282-
" .grid(grid)\n",
283-
" .build();\n",
284-
" return new Figure(layout, traces);\n",
285-
" }"
237+
" }\n",
238+
" Grid grid =\n",
239+
" Grid.builder()\n",
240+
" .columns(numCols)\n",
241+
" .rows(numRows)\n",
242+
" .pattern(Grid.Pattern.INDEPENDENT)\n",
243+
" .build();\n",
244+
" Layout layout =\n",
245+
" Layout.builder()\n",
246+
" .title(\"\")\n",
247+
" .xAxis(Axis.builder().title(xLabel).build())\n",
248+
" .yAxis(Axis.builder().title(yLabel).build())\n",
249+
" .width(width)\n",
250+
" .height(height)\n",
251+
" .grid(grid)\n",
252+
" .build();\n",
253+
" return new Figure(layout, traces);\n",
254+
"}"
286255
]
287256
},
288257
{
@@ -328,6 +297,13 @@
328297
"1. What can be the volitional cue when decoding a sequence token by token in machine translation? What are the nonvolitional cues and the sensory inputs?\n",
329298
"1. Randomly generate a $10 \\times 10$ matrix and use the softmax operation to ensure each row is a valid probability distribution. Visualize the output attention weights.\n"
330299
]
300+
},
301+
{
302+
"cell_type": "code",
303+
"execution_count": null,
304+
"metadata": {},
305+
"outputs": [],
306+
"source": []
331307
}
332308
],
333309
"metadata": {
@@ -342,7 +318,7 @@
342318
"mimetype": "text/x-java-source",
343319
"name": "Java",
344320
"pygments_lexer": "java",
345-
"version": "11.0.10+9"
321+
"version": "14.0.2+12"
346322
}
347323
},
348324
"nbformat": 4,

chapter_attention-mechanisms/attention-scoring-functions.ipynb

+9-48
Original file line numberDiff line numberDiff line change
@@ -80,58 +80,12 @@
8080
"metadata": {},
8181
"outputs": [],
8282
"source": [
83-
"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
84-
"\n",
85-
"%maven ai.djl:api:0.11.0\n",
86-
"%maven org.slf4j:slf4j-api:1.7.26\n",
87-
"%maven org.slf4j:slf4j-simple:1.7.26\n",
88-
"\n",
89-
"%maven ai.djl.mxnet:mxnet-engine:0.11.0\n",
90-
"%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
91-
]
92-
},
93-
{
94-
"cell_type": "code",
95-
"execution_count": null,
96-
"metadata": {},
97-
"outputs": [],
98-
"source": [
83+
"%load ../utils/djl-imports\n",
9984
"%load ../utils/plot-utils\n",
10085
"%load ../utils/Functions.java\n",
10186
"%load ../utils/PlotUtils.java"
10287
]
10388
},
104-
{
105-
"cell_type": "code",
106-
"execution_count": null,
107-
"metadata": {},
108-
"outputs": [],
109-
"source": [
110-
"import ai.djl.Model;\n",
111-
"import ai.djl.ndarray.*;\n",
112-
"import ai.djl.ndarray.types.DataType;\n",
113-
"import ai.djl.ndarray.types.Shape;\n",
114-
"import ai.djl.nn.AbstractBlock;\n",
115-
"import ai.djl.nn.Parameter;\n",
116-
"import ai.djl.training.*;\n",
117-
"import ai.djl.nn.core.Linear;\n",
118-
"import ai.djl.nn.norm.Dropout;\n",
119-
"import ai.djl.training.listener.TrainingListener;\n",
120-
"import ai.djl.training.loss.Loss;\n",
121-
"import ai.djl.training.optimizer.Optimizer;\n",
122-
"import ai.djl.training.tracker.Tracker;\n",
123-
"import ai.djl.training.ParameterStore;\n",
124-
"import ai.djl.training.initializer.UniformInitializer;\n",
125-
"import ai.djl.util.PairList;\n",
126-
"import ai.djl.translate.TranslateException;\n",
127-
"import tech.tablesaw.plotly.Plot;\n",
128-
"import tech.tablesaw.plotly.components.*;\n",
129-
"import tech.tablesaw.plotly.traces.ScatterTrace;\n",
130-
"\n",
131-
"import java.io.IOException;\n",
132-
"import java.util.function.Function;"
133-
]
134-
},
13589
{
13690
"cell_type": "code",
13791
"execution_count": null,
@@ -608,6 +562,13 @@
608562
"1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?\n",
609563
"1. When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?\n"
610564
]
565+
},
566+
{
567+
"cell_type": "code",
568+
"execution_count": null,
569+
"metadata": {},
570+
"outputs": [],
571+
"source": []
611572
}
612573
],
613574
"metadata": {
@@ -622,7 +583,7 @@
622583
"mimetype": "text/x-java-source",
623584
"name": "Java",
624585
"pygments_lexer": "java",
625-
"version": "11.0.10+9"
586+
"version": "14.0.2+12"
626587
}
627588
},
628589
"nbformat": 4,

chapter_attention-mechanisms/multihead-attention.ipynb

+7-35
Original file line numberDiff line numberDiff line change
@@ -93,43 +93,15 @@
9393
"metadata": {},
9494
"outputs": [],
9595
"source": [
96-
"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
97-
"\n",
98-
"%maven ai.djl:api:0.11.0\n",
99-
"%maven org.slf4j:slf4j-api:1.7.26\n",
100-
"%maven org.slf4j:slf4j-simple:1.7.26\n",
101-
"\n",
102-
"%maven ai.djl.mxnet:mxnet-engine:0.11.0\n",
103-
"%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
104-
]
105-
},
106-
{
107-
"cell_type": "code",
108-
"execution_count": null,
109-
"metadata": {},
110-
"outputs": [],
111-
"source": [
96+
"%load ../utils/djl-imports\n",
11297
"%load ../utils/plot-utils\n",
11398
"%load ../utils/Functions.java\n",
11499
"%load ../utils/PlotUtils.java\n",
115-
"%load ../utils/AttentionUtils.java"
116-
]
117-
},
118-
{
119-
"cell_type": "code",
120-
"execution_count": null,
121-
"metadata": {},
122-
"outputs": [],
123-
"source": [
124-
"import ai.djl.ndarray.*;\n",
125-
"import ai.djl.ndarray.types.DataType;\n",
126-
"import ai.djl.ndarray.types.Shape;\n",
127-
"import ai.djl.nn.AbstractBlock;\n",
128-
"import ai.djl.nn.Parameter;\n",
129-
"import ai.djl.nn.core.Linear;\n",
130-
"import ai.djl.nn.norm.Dropout;\n",
131-
"import ai.djl.training.ParameterStore;\n",
132-
"import ai.djl.util.PairList;"
100+
"\n",
101+
"%load ../utils/attention/Chap10Utils.java\n",
102+
"%load ../utils/attention/DotProductAttention.java\n",
103+
"%load ../utils/attention/MultiHeadAttention.java\n",
104+
"%load ../utils/attention/PositionalEncoding.java"
133105
]
134106
},
135107
{
@@ -394,7 +366,7 @@
394366
"mimetype": "text/x-java-source",
395367
"name": "Java",
396368
"pygments_lexer": "java",
397-
"version": "11.0.10+9"
369+
"version": "14.0.2+12"
398370
}
399371
},
400372
"nbformat": 4,

chapter_attention-mechanisms/nadaraya-watson.ipynb

+9-46
Original file line numberDiff line numberDiff line change
@@ -32,57 +32,13 @@
3232
"metadata": {},
3333
"outputs": [],
3434
"source": [
35-
"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
36-
"\n",
37-
"%maven ai.djl:api:0.11.0\n",
38-
"%maven org.slf4j:slf4j-api:1.7.26\n",
39-
"%maven org.slf4j:slf4j-simple:1.7.26\n",
40-
"\n",
41-
"%maven ai.djl.mxnet:mxnet-engine:0.11.0\n",
42-
"%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
43-
]
44-
},
45-
{
46-
"cell_type": "code",
47-
"execution_count": null,
48-
"metadata": {},
49-
"outputs": [],
50-
"source": [
35+
"%load ../utils/djl-imports\n",
5136
"%load ../utils/plot-utils\n",
5237
"%load ../utils/Functions.java\n",
5338
"%load ../utils/Animator.java\n",
5439
"%load ../utils/PlotUtils.java"
5540
]
5641
},
57-
{
58-
"cell_type": "code",
59-
"execution_count": null,
60-
"metadata": {},
61-
"outputs": [],
62-
"source": [
63-
"import ai.djl.Model;\n",
64-
"import ai.djl.ndarray.*;\n",
65-
"import ai.djl.ndarray.types.DataType;\n",
66-
"import ai.djl.ndarray.types.Shape;\n",
67-
"import ai.djl.nn.AbstractBlock;\n",
68-
"import ai.djl.nn.Parameter;\n",
69-
"import ai.djl.training.*;\n",
70-
"import ai.djl.training.listener.TrainingListener;\n",
71-
"import ai.djl.training.loss.Loss;\n",
72-
"import ai.djl.training.optimizer.Optimizer;\n",
73-
"import ai.djl.training.tracker.Tracker;\n",
74-
"import ai.djl.training.ParameterStore;\n",
75-
"import ai.djl.training.initializer.UniformInitializer;\n",
76-
"import ai.djl.util.PairList;\n",
77-
"import ai.djl.translate.TranslateException;\n",
78-
"import tech.tablesaw.plotly.Plot;\n",
79-
"import tech.tablesaw.plotly.components.*;\n",
80-
"import tech.tablesaw.plotly.traces.ScatterTrace;\n",
81-
"\n",
82-
"import java.io.IOException;\n",
83-
"import java.util.function.Function;"
84-
]
85-
},
8642
{
8743
"cell_type": "code",
8844
"execution_count": null,
@@ -651,6 +607,13 @@
651607
"1. How can we add hyperparameters to nonparametric Nadaraya-Watson kernel regression to predict better?\n",
652608
"1. Design another parametric attention pooling for the kernel regression of this section. Train this new model and visualize its attention weights.\n"
653609
]
610+
},
611+
{
612+
"cell_type": "code",
613+
"execution_count": null,
614+
"metadata": {},
615+
"outputs": [],
616+
"source": []
654617
}
655618
],
656619
"metadata": {
@@ -665,7 +628,7 @@
665628
"mimetype": "text/x-java-source",
666629
"name": "Java",
667630
"pygments_lexer": "java",
668-
"version": "11.0.10+9"
631+
"version": "14.0.2+12"
669632
}
670633
},
671634
"nbformat": 4,

0 commit comments

Comments
 (0)