diff --git a/dataset.ipynb b/dataset.ipynb
new file mode 100644
index 0000000..f190f78
--- /dev/null
+++ b/dataset.ipynb
@@ -0,0 +1,1225 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-12-10 15:44:03.538568: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2023-12-10 15:44:03.543441: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2023-12-10 15:44:03.619818: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2023-12-10 15:44:03.621088: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2023-12-10 15:44:05.873463: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading dataset...\n"
+ ]
+ }
+ ],
+ "source": [
+ "from load_fer2013 import load_fer2013\n",
+ "\n",
+ "\n",
+ "ds = load_fer2013()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " emotion | \n",
+ " pixels | \n",
+ " Usage | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Angry | \n",
+ " 70 80 82 72 58 58 60 63 54 58 60 48 89 115 121... | \n",
+ " Training | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Angry | \n",
+ " 151 150 147 155 148 133 111 140 170 174 182 15... | \n",
+ " Training | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Fear | \n",
+ " 231 212 156 164 174 138 161 173 182 200 106 38... | \n",
+ " Training | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Sad | \n",
+ " 24 32 36 30 32 23 19 20 30 41 21 22 32 34 21 1... | \n",
+ " Training | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Neutral | \n",
+ " 4 0 0 0 0 0 0 0 0 0 0 0 3 15 23 28 48 50 58 84... | \n",
+ " Training | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 35882 | \n",
+ " Neutral | \n",
+ " 50 36 17 22 23 29 33 39 34 37 37 37 39 43 48 5... | \n",
+ " PrivateTest | \n",
+ "
\n",
+ " \n",
+ " 35883 | \n",
+ " Happy | \n",
+ " 178 174 172 173 181 188 191 194 196 199 200 20... | \n",
+ " PrivateTest | \n",
+ "
\n",
+ " \n",
+ " 35884 | \n",
+ " Angry | \n",
+ " 17 17 16 23 28 22 19 17 25 26 20 24 31 19 27 9... | \n",
+ " PrivateTest | \n",
+ "
\n",
+ " \n",
+ " 35885 | \n",
+ " Happy | \n",
+ " 30 28 28 29 31 30 42 68 79 81 77 67 67 71 63 6... | \n",
+ " PrivateTest | \n",
+ "
\n",
+ " \n",
+ " 35886 | \n",
+ " Fear | \n",
+ " 19 13 14 12 13 16 21 33 50 57 71 84 97 108 122... | \n",
+ " PrivateTest | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
35887 rows × 3 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " emotion pixels Usage\n",
+ "0 Angry 70 80 82 72 58 58 60 63 54 58 60 48 89 115 121... Training\n",
+ "1 Angry 151 150 147 155 148 133 111 140 170 174 182 15... Training\n",
+ "2 Fear 231 212 156 164 174 138 161 173 182 200 106 38... Training\n",
+ "3 Sad 24 32 36 30 32 23 19 20 30 41 21 22 32 34 21 1... Training\n",
+ "4 Neutral 4 0 0 0 0 0 0 0 0 0 0 0 3 15 23 28 48 50 58 84... Training\n",
+ "... ... ... ...\n",
+ "35882 Neutral 50 36 17 22 23 29 33 39 34 37 37 37 39 43 48 5... PrivateTest\n",
+ "35883 Happy 178 174 172 173 181 188 191 194 196 199 200 20... PrivateTest\n",
+ "35884 Angry 17 17 16 23 28 22 19 17 25 26 20 24 31 19 27 9... PrivateTest\n",
+ "35885 Happy 30 28 28 29 31 30 42 68 79 81 77 67 67 71 63 6... PrivateTest\n",
+ "35886 Fear 19 13 14 12 13 16 21 33 50 57 71 84 97 108 122... PrivateTest\n",
+ "\n",
+ "[35887 rows x 3 columns]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# map emotion integer to [\"Angry\", \"Disgust\", \"Fear\", \"Happy\", \"Sad\", \"Surprise\", \"Neutral\"]\n",
+ "emotion_labels = {0: \"Angry\", 1: \"Disgust\", 2: \"Fear\", 3: \"Happy\", 4: \"Sad\", 5: \"Surprise\", 6: \"Neutral\"}\n",
+ "\n",
+ "# Replace integer values with corresponding emotion labels\n",
+ "ds['emotion'] = ds['emotion'].replace(emotion_labels)\n",
+ "ds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/martin/.local/lib/python3.11/site-packages/pydantic/_internal/_fields.py:128: UserWarning:\n",
+ "\n",
+ "Field \"model_server_url\" has conflict with protected namespace \"model_\".\n",
+ "\n",
+ "You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.\n",
+ "\n",
+ "/home/martin/.local/lib/python3.11/site-packages/pydantic/_internal/_config.py:317: UserWarning:\n",
+ "\n",
+ "Valid config keys have changed in V2:\n",
+ "* 'schema_extra' has been renamed to 'json_schema_extra'\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from typing import Tuple\n",
+ "import mlflow.tensorflow\n",
+ "import tensorflow as tf\n",
+ "from load_fer2013 import load_fer2013, preprocess\n",
+ "\n",
+ "def load_and_preprocess_data() -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:\n",
+ " data = load_fer2013()\n",
+ " num_classes = 7\n",
+ "\n",
+ " # Define splits for train, validation, and test sets\n",
+ " split_train = int(len(data) * 0.7)\n",
+ " split_test = int(len(data) * 0.1)\n",
+ " split_val = len(data) - split_train - split_test\n",
+ "\n",
+ " # Create a TensorFlow dataset from the data\n",
+ " dataset = tf.data.Dataset.from_tensor_slices(dict(data))\n",
+ " dataset = dataset.map(\n",
+ " lambda row: preprocess(row, num_classes), num_parallel_calls=tf.data.AUTOTUNE\n",
+ " )\n",
+ "\n",
+ " # Partition the data into train, validation, and test sets\n",
+ " train_dataset = (\n",
+ " dataset.take(split_train).shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)\n",
+ " )\n",
+ " val_dataset = (\n",
+ " dataset.skip(split_train).take(split_val).batch(32).prefetch(tf.data.AUTOTUNE)\n",
+ " )\n",
+ " test_dataset = (\n",
+ " dataset.skip(split_train + split_val).batch(32).prefetch(tf.data.AUTOTUNE)\n",
+ " )\n",
+ "\n",
+ " return train_dataset, val_dataset, test_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading dataset...\n",
+ "Train dataset size: 25120 examples\n",
+ "Validation dataset size: 7179 examples\n",
+ "Test dataset size: 3588 examples\n"
+ ]
+ }
+ ],
+ "source": [
+ "def count_dataset_examples(dataset):\n",
+ " count = 0\n",
+ " for batch in dataset:\n",
+ " count += batch[0].shape[0]\n",
+ " return count\n",
+ "\n",
+ "def print_dataset_sizes(train_dataset, val_dataset, test_dataset):\n",
+ " train_size = count_dataset_examples(train_dataset)\n",
+ " val_size = count_dataset_examples(val_dataset)\n",
+ " test_size = count_dataset_examples(test_dataset)\n",
+ "\n",
+ " print(f\"Train dataset size: {train_size} examples\")\n",
+ " print(f\"Validation dataset size: {val_size} examples\")\n",
+ " print(f\"Test dataset size: {test_size} examples\")\n",
+ "\n",
+ "train_dataset, val_dataset, test_dataset = load_and_preprocess_data()\n",
+ "\n",
+ "print_dataset_sizes(train_dataset, val_dataset, test_dataset)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "alignmentgroup": "True",
+ "hovertemplate": "emotion=%{x}
Count=%{y}",
+ "legendgroup": "",
+ "marker": {
+ "color": "#636efa",
+ "pattern": {
+ "shape": ""
+ }
+ },
+ "name": "",
+ "offsetgroup": "",
+ "orientation": "v",
+ "showlegend": false,
+ "textposition": "auto",
+ "type": "bar",
+ "x": [
+ "Happy",
+ "Neutral",
+ "Sad",
+ "Fear",
+ "Angry",
+ "Surprise",
+ "Disgust"
+ ],
+ "xaxis": "x",
+ "y": [
+ 8989,
+ 6198,
+ 6077,
+ 5121,
+ 4953,
+ 4002,
+ 547
+ ],
+ "yaxis": "y"
+ }
+ ],
+ "layout": {
+ "barmode": "relative",
+ "legend": {
+ "tracegroupgap": 0
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Emotion Distribution"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "emotion"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Count"
+ }
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import plotly.express as px\n",
+ "\n",
+ "# Count the occurrences of each emotion\n",
+ "emotion_counts = ds['emotion'].value_counts()\n",
+ "\n",
+ "fig = px.bar(emotion_counts, \n",
+ " x=emotion_counts.index, \n",
+ " y=emotion_counts.values, \n",
+ " title='Emotion Distribution',\n",
+ " labels={'x': 'Emotion', 'y': 'Count'})\n",
+ "\n",
+ "# Show the plot\n",
+ "fig.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Baseline accuracy: 0.25\n"
+ ]
+ }
+ ],
+ "source": [
+ "emotion_counts = ds['emotion'].value_counts()\n",
+ "happy_count = emotion_counts['Happy']\n",
+ "\n",
+ "# Calculate the baseline accuracy\n",
+ "baseline_acc = happy_count / sum(emotion_counts)\n",
+ "print(f\"Baseline accuracy: {baseline_acc:.2f}\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/images/class_distribution.png b/images/class_distribution.png
new file mode 100644
index 0000000..055f1b2
Binary files /dev/null and b/images/class_distribution.png differ
diff --git a/images/image-1.png b/images/image-1.png
new file mode 100644
index 0000000..12a0d5f
Binary files /dev/null and b/images/image-1.png differ
diff --git a/images/image-2.png b/images/image-2.png
new file mode 100644
index 0000000..103ae78
Binary files /dev/null and b/images/image-2.png differ
diff --git a/images/image.png b/images/image.png
new file mode 100644
index 0000000..04e3594
Binary files /dev/null and b/images/image.png differ
diff --git a/images/lr1.png b/images/lr1.png
new file mode 100644
index 0000000..31027bf
Binary files /dev/null and b/images/lr1.png differ
diff --git a/images/lr2.png b/images/lr2.png
new file mode 100644
index 0000000..ad19f0e
Binary files /dev/null and b/images/lr2.png differ
diff --git a/images/val_acc1.png b/images/val_acc1.png
new file mode 100644
index 0000000..1e407c5
Binary files /dev/null and b/images/val_acc1.png differ
diff --git a/images/val_acc2.png b/images/val_acc2.png
new file mode 100644
index 0000000..a0fa434
Binary files /dev/null and b/images/val_acc2.png differ
diff --git a/images/val_acc3.png b/images/val_acc3.png
new file mode 100644
index 0000000..51bfb9c
Binary files /dev/null and b/images/val_acc3.png differ
diff --git a/images/val_loss1.png b/images/val_loss1.png
new file mode 100644
index 0000000..f4617ee
Binary files /dev/null and b/images/val_loss1.png differ
diff --git a/images/val_loss2.png b/images/val_loss2.png
new file mode 100644
index 0000000..972ab1f
Binary files /dev/null and b/images/val_loss2.png differ
diff --git a/load_fer2013.py b/load_fer2013.py
new file mode 100644
index 0000000..b806512
--- /dev/null
+++ b/load_fer2013.py
@@ -0,0 +1,38 @@
+import os
+import subprocess
+import pandas as pd
+import tensorflow as tf
+from typing import Tuple
+
+
+def load_fer2013() -> pd.DataFrame:
+ """Load the FER2013 emotion dataset as a pandas DataFrame, downloading it first if needed."""
+ if not os.path.exists("fer2013"):
+ print("Downloading the face emotion dataset...")
+ subprocess.check_output(
+ "curl -SL https://www.dropbox.com/s/opuvvdv3uligypx/fer2013.tar | tar xz",
+ shell=True,
+ )
+ print("Loading dataset...")
+ data = pd.read_csv("fer2013/fer2013.csv")
+ return data
+
+
+def preprocess(row, num_classes: int) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Convert one dataset row into a normalized 48x48x1 image tensor and a one-hot emotion label."""
+ # Convert the 'pixels' tensor to string and split
+ pixel_string = row["pixels"]
+ pixel_values = tf.strings.split([pixel_string], sep=" ")
+ pixel_values = tf.strings.to_number(pixel_values, out_type=tf.int32)
+
+ # Convert the RaggedTensor to a regular tensor
+ pixel_values = tf.RaggedTensor.to_tensor(pixel_values, default_value=0)
+
+ # Reshape and normalize the pixel values
+ pixels = tf.reshape(pixel_values, (48, 48, 1))
+ pixels = tf.cast(pixels, tf.float32) / 255.0
+
+ # Prepare the label
+ emotion = tf.one_hot(row["emotion"], depth=num_classes)
+
+ return pixels, emotion
diff --git a/mlflow-start.sh b/mlflow-start.sh
new file mode 100755
index 0000000..65ec463
--- /dev/null
+++ b/mlflow-start.sh
@@ -0,0 +1 @@
+mlflow server --host 127.0.0.1 --port 8080
\ No newline at end of file
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..cbc4244
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,113 @@
+# Facial Emotion Recognition
+
+This repository is my project hand-in for the AKT3 course on Deep Learning & Computer Vision.
+
+## Dataset
+
+For training this model we will be using the [FER2013](https://www.kaggle.com/datasets/msambare/fer2013) dataset.
+
+### Example Data
+
+The dataset contains 48x48 images of human faces.
+
+
+
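+Each image is stored in the CSV as a space-separated string of 48 × 48 = 2304 grayscale values (see the `pixels` column in `dataset.ipynb`). A minimal sketch of decoding one row, assuming the dataset has already been downloaded via `load_fer2013()` and using matplotlib purely for illustration (it is not a project dependency):
+
+``` python
+import numpy as np
+import matplotlib.pyplot as plt
+
+from load_fer2013 import load_fer2013
+
+ds = load_fer2013()
+
+# The 'pixels' column holds 2304 space-separated grayscale values per face
+pixels = np.array(ds.loc[0, "pixels"].split(), dtype=np.uint8).reshape(48, 48)
+
+plt.imshow(pixels, cmap="gray")
+plt.title(f"emotion label: {ds.loc[0, 'emotion']}")
+plt.show()
+```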
+
+### Analysis
+
+
+
+Using the distribution we can determine a baseline accuracy.
+
+`HappyCounts / TotalCounts = 8989 / 35887 ≈ 0.25`
+
+`Baseline accuracy ≈ 25%`
+
+So by always guessing `Happy` we could reach an accuracy of 25%. Our goal is to improve that with the CNN.
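+
+The same figure can be reproduced directly from the raw data; a minimal sketch (the counts match the output in `dataset.ipynb`):
+
+``` python
+from load_fer2013 import load_fer2013
+
+ds = load_fer2013()
+
+# Accuracy of always predicting the majority class ("Happy", 8989 of 35887 examples)
+counts = ds["emotion"].value_counts()
+baseline_acc = counts.max() / counts.sum()
+print(f"Baseline accuracy: {baseline_acc:.2f}")  # 0.25
+```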
+
+## Baseline
+
+Using the `train.py` script we train a facial emotion recognition model that classifies images of human faces into 7 emotion categories (`"Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"`).
+
+We split the dataset into train (70%), validation (~20%), and test (10%) sets; the sketch after the list below shows how the sizes are derived.
+
+* Train dataset size: 25120 examples
+* Validation dataset size: 7179 examples
+* Test dataset size: 3588 examples
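+
+These sizes correspond to a roughly 70 / 20 / 10 split of the 35887 rows; a condensed sketch of how `load_and_preprocess_data()` in `train.py` derives them:
+
+``` python
+n = 35887  # total number of rows in fer2013.csv
+
+split_train = int(n * 0.7)                 # 25120 training examples
+split_test = int(n * 0.1)                  # 3588 test examples
+split_val = n - split_train - split_test   # 7179 validation examples
+```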
+
+### Results
+
+
+
+As shown in the graphs above, we achieve very poor performance with the baseline parameters.
+
+| Parameter | Value |
+|--------------------------|--------------------------|
+| learning_rate | 0.01 |
+| loss | categorical_crossentropy |
+| epochs | 50 |
+| batch_size | 128 |
+| early_stopping_patience | 7 |
+| lr_patience | 5 |
+| lr_reduction_factor | 0.1 |
+| optimizer | Adam |
+| num_classes | 7 |
+| input_shape | (48, 48, 1) |
+| shuffle | True |
+| restore_best_weights | True |
+
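+The patience and learning-rate parameters above are wired into Keras callbacks in `train.py`; a condensed sketch with the baseline values hard-coded for readability (see `train.py` for the actual configuration):
+
+``` python
+from tensorflow import keras
+
+callbacks = [
+    # Stop training once val_loss has not improved for 7 epochs
+    keras.callbacks.EarlyStopping(patience=7, restore_best_weights=True),
+    # Keep a checkpoint of the best model seen so far
+    keras.callbacks.ModelCheckpoint("./output/best_model", save_best_only=True),
+    # Multiply the learning rate by 0.1 after 5 epochs without improvement
+    keras.callbacks.ReduceLROnPlateau(factor=0.1, patience=5),
+]
+```
+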
+## Experiment 1 - Improving validation-accuracy
+
+In my first run the model only achieved a validation accuracy of 21%, which is very poor. This was surprising, because other resources report significantly higher validation accuracies on this dataset with similar CNNs.
+
+My hypothesis is that I chose a far too high initial learning rate, which led to very early convergence and therefore significant underfitting.
+By reducing the learning rate I expect better results.
+
+| Parameter | Value |
+|--------------------------|-----------------------|
+| learning_rate | 0.001 |
+
+### Results
+
+Validation Loss | Validation Accuracy | Learning Rate
+:------------------------:|:------------------------:|:-------------------------:
+ |  | 
+
+As the charts show, my hypothesis was correct: reducing the learning rate yields much better results.
+
+## Experiment 2 - Smoothing the validation-loss curve
+
+The new validation-loss curve is very erratic. To smooth it out and reduce the bumpiness, I will again lower the learning rate by a factor of 10.
+
+| Parameter | Value |
+|--------------------------|-----------------------|
+| learning_rate | 0.0001 |
+
+### Results
+
+As the charts show, my hypothesis was correct: with the lower learning rate the validation-loss curve is much less erratic.
+
+Validation Loss | Validation Accuracy | Learning Rate
+:------------------------:|:------------------------:|:-------------------------:
+ |  | 
+
+## Experiment 3 - Disabling the restore-best-weights option
+
+For some reason, the `restore_best_weights` option of the EarlyStopping callback actually restores a worse configuration of the model at the end of training. By disabling the option we want to prevent that behaviour.
+
+``` python
+keras.callbacks.EarlyStopping(
+ patience=params["early_stopping_patience"],
+ restore_best_weights=False
+ ),
+```
+
+| Parameter | Value |
+|--------------------------|-----------------------|
+| restore_best_weights | False |
+
+### Results
+
+In the following graph we see that by disabling the `restore_best_weights` option we can actually keep the better model in the end.
+
+![Validation accuracy of Experiment 3](images/val_acc3.png)
\ No newline at end of file
diff --git a/readme.pdf b/readme.pdf
new file mode 100644
index 0000000..c28aaea
Binary files /dev/null and b/readme.pdf differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..00835fa
--- /dev/null
+++ b/train.py
@@ -0,0 +1,146 @@
+from typing import Tuple, Dict
+import mlflow
+import mlflow.tensorflow
+from datetime import datetime
+from tensorflow import keras
+from keras.layers import (
+ Dense,
+ Flatten,
+ Conv2D,
+ MaxPooling2D,
+ Dropout,
+ BatchNormalization,
+)
+from keras.models import Sequential
+from keras.utils import plot_model
+import tensorflow as tf
+from load_fer2013 import load_fer2013, preprocess
+
+
+def setup_mlflow() -> None:
+ mlflow.set_tracking_uri("http://127.0.0.1:8080")
+ experiment_name = "Baseline"
+ experiment_description = (
+ "This is a neural network for classifiying human emotions based on facial expressions."
+ "This experiment will create a baseline neural network for further experiments."
+ )
+ experiment_tags = {
+ "project_name": "facial-emotion-recognition",
+ "experiment_name": experiment_name,
+ "dataset": "fer2013",
+ "mlflow.note.content": experiment_description,
+ "date": datetime.now().strftime("%d.%m.%Y %H:%M"),
+ }
+ mlflow.set_experiment(experiment_name)
+ mlflow.set_experiment_tags(experiment_tags)
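+ # Autologging records per-epoch metrics, parameters, and the trained model in MLflow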
+ mlflow.tensorflow.autolog()
+
+
+def create_model(
+ input_shape: Tuple[int, int, int], num_classes: int, params
+) -> Sequential:
+ model = Sequential(
+ [
+ Conv2D(32, (3, 3), activation="relu", input_shape=input_shape),
+ MaxPooling2D(),
+ BatchNormalization(),
+ Conv2D(64, (3, 3), activation="relu"),
+ MaxPooling2D(),
+ BatchNormalization(),
+ Flatten(),
+ Dense(128, activation="relu"),
+ Dropout(0.5),
+ Dense(num_classes, activation="softmax"),
+ ]
+ )
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=params["learning_rate"]),
+ loss=tf.keras.losses.CategoricalCrossentropy(),
+ metrics=["accuracy", keras.metrics.CategoricalAccuracy()],
+ )
+ return model
+
+
+def load_and_preprocess_data() -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
+ data = load_fer2013()
+ num_classes = 7
+
+ # Define splits for train, validation, and test sets
+ split_train = int(len(data) * 0.7)
+ split_test = int(len(data) * 0.1)
+ split_val = len(data) - split_train - split_test
+
+ # Create a TensorFlow dataset from the data
+ dataset = tf.data.Dataset.from_tensor_slices(dict(data))
+ dataset = dataset.map(
+ lambda row: preprocess(row, num_classes), num_parallel_calls=tf.data.AUTOTUNE
+ )
+
+ # Partition the data into train, validation, and test sets
+ train_dataset = (
+ dataset.take(split_train).shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
+ )
+ val_dataset = (
+ dataset.skip(split_train).take(split_val).batch(32).prefetch(tf.data.AUTOTUNE)
+ )
+ test_dataset = (
+ dataset.skip(split_train + split_val).batch(32).prefetch(tf.data.AUTOTUNE)
+ )
+
+ return train_dataset, val_dataset, test_dataset
+
+
+def train_and_log_model(
+ model: Sequential,
+ train_dataset: tf.data.Dataset,
+ val_dataset: tf.data.Dataset,
+ params: Dict[str, str | int | float | Tuple[int, int, int]],
+) -> None:
+ model.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=params["epochs"], # type: ignore
+ batch_size=params["batch_size"],
+ callbacks=[
+ keras.callbacks.EarlyStopping(
+ patience=params["early_stopping_patience"], # type: ignore
+ # restore_best_weights=True,  # intentionally disabled: with it enabled the final model ended up worse (see Experiment 3 in readme.md)
+ ),
+ keras.callbacks.ModelCheckpoint("./output/best_model", save_best_only=True),
+ keras.callbacks.ReduceLROnPlateau(
+ factor=params["lr_reduction_factor"], patience=params["lr_patience"] # type: ignore
+ ),
+ ],
+ )
+ model.save("./output/emotion.h5")
+ mlflow.log_params(params)
+ plot_model(model, to_file="./output/model.png", show_shapes=True)
+ mlflow.log_artifact("./output/model.png")
+ model.save_weights("./output/model_weights/model_weights")
+ mlflow.log_artifact("./output/model_weights")
+
+
+if __name__ == "__main__":
+ setup_mlflow()
+ with mlflow.start_run() as run:
+ input_shape = (48, 48, 1)
+ num_classes = 7
+ params = {
+ "batch_size": 128,
+ "epochs": 50,
+ "input_shape": input_shape,
+ "num_classes": num_classes,
+ "optimizer": "adam",
+ "loss": "categorical_crossentropy",
+ "early_stopping_patience": 5,
+ "learning_rate": 0.0001,
+ "lr_reduction_factor": 0.1,
+ "lr_patience": 3,
+ }
+
+ train_dataset, val_dataset, test_dataset = load_and_preprocess_data()
+ model = create_model(input_shape, num_classes, params)
+ train_and_log_model(model, train_dataset, val_dataset, params)