Skip to content

Commit 235669d

Browse files
committed
refactor: improve code readability and formatting in notebooks
1 parent f406a1b commit 235669d

12 files changed

+1129
-1005
lines changed

lesson-1-introduction-to-transformer-neural-networks/demos/02-char-tokenizer.ipynb

+23-19
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,33 @@
108108
"outputs": [],
109109
"source": [
110110
"class CharTokenizer:\n",
111-
" def __init__(self, vocabulary):\n",
112-
" self.token_id_for_char = {char: token_id for token_id, char in enumerate(vocabulary)}\n",
113-
" self.char_for_token_id = {token_id: char for token_id, char in enumerate(vocabulary)}\n",
111+
" def __init__(self, vocabulary):\n",
112+
" self.token_id_for_char = {\n",
113+
" char: token_id for token_id, char in enumerate(vocabulary)\n",
114+
" }\n",
115+
" self.char_for_token_id = {\n",
116+
" token_id: char for token_id, char in enumerate(vocabulary)\n",
117+
" }\n",
114118
"\n",
115-
" @staticmethod\n",
116-
" def train_from_text(text):\n",
117-
" vocabulary = set(text)\n",
118-
" return CharTokenizer(sorted(list(vocabulary)))\n",
119+
" @staticmethod\n",
120+
" def train_from_text(text):\n",
121+
" vocabulary = set(text)\n",
122+
" return CharTokenizer(sorted(list(vocabulary)))\n",
119123
"\n",
120-
" def encode(self, text):\n",
121-
" token_ids = []\n",
122-
" for char in text:\n",
123-
" token_ids.append(self.token_id_for_char[char])\n",
124-
" return torch.tensor(token_ids, dtype=torch.long)\n",
124+
" def encode(self, text):\n",
125+
" token_ids = []\n",
126+
" for char in text:\n",
127+
" token_ids.append(self.token_id_for_char[char])\n",
128+
" return torch.tensor(token_ids, dtype=torch.long)\n",
125129
"\n",
126-
" def decode(self, token_ids):\n",
127-
" chars = []\n",
128-
" for token_id in token_ids.tolist():\n",
129-
" chars.append(self.char_for_token_id[token_id])\n",
130-
" return ''.join(chars)\n",
130+
" def decode(self, token_ids):\n",
131+
" chars = []\n",
132+
" for token_id in token_ids.tolist():\n",
133+
" chars.append(self.char_for_token_id[token_id])\n",
134+
" return \"\".join(chars)\n",
131135
"\n",
132-
" def vocabulary_size(self):\n",
133-
" return len(self.token_id_for_char)"
136+
" def vocabulary_size(self):\n",
137+
" return len(self.token_id_for_char)"
134138
]
135139
},
136140
{

lesson-1-introduction-to-transformer-neural-networks/demos/03-implementing-the-attention-block.ipynb

+88-79
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 2,
21+
"execution_count": null,
2222
"metadata": {},
2323
"outputs": [],
2424
"source": [
2525
"from pathlib import Path\n",
2626
"\n",
27-
"text = Path('../data/tiny-shakespeare.txt').read_text()"
27+
"text = Path(\"../data/tiny-shakespeare.txt\").read_text()"
2828
]
2929
},
3030
{
@@ -88,38 +88,40 @@
8888
},
8989
{
9090
"cell_type": "code",
91-
"execution_count": 4,
91+
"execution_count": null,
9292
"metadata": {
9393
"id": "Ap_Ixr0M-0Yv"
9494
},
9595
"outputs": [],
9696
"source": [
97-
"\n",
9897
"class CharTokenizer:\n",
99-
" def __init__(self, vocabulary):\n",
100-
" self.token_id_for_char = {char: token_id for token_id, char in enumerate(vocabulary)}\n",
101-
" self.char_for_token_id = {token_id: char for token_id, char in enumerate(vocabulary)}\n",
98+
" def __init__(self, vocabulary):\n",
99+
" self.token_id_for_char = {\n",
100+
" char: token_id for token_id, char in enumerate(vocabulary)\n",
101+
" }\n",
102+
" self.char_for_token_id = {\n",
103+
" token_id: char for token_id, char in enumerate(vocabulary)\n",
104+
" }\n",
102105
"\n",
103-
" @staticmethod\n",
104-
" def train_from_text(text):\n",
105-
" vocabulary = set(text)\n",
106-
" return CharTokenizer(sorted(list(vocabulary)))\n",
106+
" @staticmethod\n",
107+
" def train_from_text(text):\n",
108+
" vocabulary = set(text)\n",
109+
" return CharTokenizer(sorted(list(vocabulary)))\n",
107110
"\n",
108-
" def encode(self, text):\n",
109-
" token_ids = []\n",
110-
" for char in text:\n",
111-
" token_ids.append(self.token_id_for_char[char])\n",
112-
" return torch.tensor(token_ids, dtype=torch.long)\n",
111+
" def encode(self, text):\n",
112+
" token_ids = []\n",
113+
" for char in text:\n",
114+
" token_ids.append(self.token_id_for_char[char])\n",
115+
" return torch.tensor(token_ids, dtype=torch.long)\n",
113116
"\n",
114-
" def decode(self, token_ids):\n",
115-
" chars = []\n",
116-
" for token_id in token_ids.tolist():\n",
117-
" chars.append(self.char_for_token_id[token_id])\n",
118-
" return ''.join(chars)\n",
117+
" def decode(self, token_ids):\n",
118+
" chars = []\n",
119+
" for token_id in token_ids.tolist():\n",
120+
" chars.append(self.char_for_token_id[token_id])\n",
121+
" return \"\".join(chars)\n",
119122
"\n",
120-
"\n",
121-
" def vocabulary_size(self):\n",
122-
" return len(self.token_id_for_char)"
123+
" def vocabulary_size(self):\n",
124+
" return len(self.token_id_for_char)"
123125
]
124126
},
125127
{
@@ -175,84 +177,91 @@
175177
},
176178
{
177179
"cell_type": "code",
178-
"execution_count": 8,
180+
"execution_count": null,
179181
"metadata": {
180182
"id": "7Qal76ig-94U"
181183
},
182184
"outputs": [],
183185
"source": [
184186
"from torch.utils.data import Dataset\n",
185187
"\n",
188+
"\n",
186189
"class TokenIdsDataset(Dataset):\n",
187-
" def __init__(self, data, block_size):\n",
188-
" self.data = data\n",
189-
" self.block_size = block_size\n",
190+
" def __init__(self, data, block_size):\n",
191+
" self.data = data\n",
192+
" self.block_size = block_size\n",
190193
"\n",
191-
" def __len__(self):\n",
192-
" return len(self.data) - self.block_size\n",
194+
" def __len__(self):\n",
195+
" return len(self.data) - self.block_size\n",
193196
"\n",
194-
" def __getitem__(self, pos):\n",
195-
" assert pos < len(self.data) - self.block_size\n",
197+
" def __getitem__(self, pos):\n",
198+
" assert pos < len(self.data) - self.block_size\n",
196199
"\n",
197-
" x = self.data[pos:pos + self.block_size]\n",
198-
" y = self.data[pos + 1:pos + 1 + self.block_size]\n",
199-
" return x, y"
200+
" x = self.data[pos : pos + self.block_size]\n",
201+
" y = self.data[pos + 1 : pos + 1 + self.block_size]\n",
202+
" return x, y"
200203
]
201204
},
202205
{
203206
"cell_type": "code",
204-
"execution_count": 10,
207+
"execution_count": null,
205208
"metadata": {},
206209
"outputs": [],
207210
"source": [
208211
"config = {\n",
209-
" \"vocabulary_size\": tokenizer.vocabulary_size(),\n",
210-
" \"context_size\": 256,\n",
211-
" \"embedding_dim\": 768,\n",
212-
" \"heads_num\": 12,\n",
213-
" \"layers_num\": 10,\n",
214-
" \"dropout_rate\": 0.1,\n",
215-
" \"use_bias\": False,\n",
212+
" \"vocabulary_size\": tokenizer.vocabulary_size(),\n",
213+
" \"context_size\": 256,\n",
214+
" \"embedding_dim\": 768,\n",
215+
" \"heads_num\": 12,\n",
216+
" \"layers_num\": 10,\n",
217+
" \"dropout_rate\": 0.1,\n",
218+
" \"use_bias\": False,\n",
216219
"}\n",
217220
"\n",
218221
"config[\"head_size\"] = config[\"embedding_dim\"] // config[\"heads_num\"]"
219222
]
220223
},
221224
{
222225
"cell_type": "code",
223-
"execution_count": 11,
226+
"execution_count": null,
224227
"metadata": {},
225228
"outputs": [],
226229
"source": [
227230
"class AttentionHead(nn.Module):\n",
228-
" def __init__(self, config):\n",
229-
" super().__init__()\n",
230-
" self.Q_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
231-
" self.K_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
232-
" self.V_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
231+
" def __init__(self, config):\n",
232+
" super().__init__()\n",
233+
" self.Q_weights = nn.Linear(\n",
234+
" config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
235+
" )\n",
236+
" self.K_weights = nn.Linear(\n",
237+
" config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
238+
" )\n",
239+
" self.V_weights = nn.Linear(\n",
240+
" config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
241+
" )\n",
233242
"\n",
234-
" self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
243+
" self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
235244
"\n",
236-
" casual_attention_mask = torch.tril(torch.ones(config[\"context_size\"], config[\"context_size\"]))\n",
237-
" self.register_buffer('casual_attention_mask', casual_attention_mask)\n",
245+
" casual_attention_mask = torch.tril(\n",
246+
" torch.ones(config[\"context_size\"], config[\"context_size\"])\n",
247+
" )\n",
248+
" self.register_buffer(\"casual_attention_mask\", casual_attention_mask)\n",
238249
"\n",
250+
" def forward(self, input): # (B, C, embedding_dim)\n",
251+
" batch_size, tokens_num, embedding_dim = input.shape\n",
252+
" Q = self.Q_weights(input) # (B, C, head_size)\n",
253+
" K = self.K_weights(input) # (B, C, head_size)\n",
254+
" V = self.V_weights(input) # (B, C, head_size)\n",
239255
"\n",
240-
" def forward(self, input): # (B, C, embedding_dim)\n",
241-
" batch_size, tokens_num, embedding_dim = input.shape\n",
242-
" Q = self.Q_weights(input) # (B, C, head_size)\n",
243-
" K = self.K_weights(input) # (B, C, head_size)\n",
244-
" V = self.V_weights(input) # (B, C, head_size)\n",
256+
" attention_scores = Q @ K.transpose(1, 2) # (B, C, C)\n",
257+
" attention_scores = attention_scores.masked_fill(\n",
258+
" self.casual_attention_mask[:tokens_num, :tokens_num] == 0, -torch.inf\n",
259+
" )\n",
260+
" attention_scores = attention_scores / (K.shape[-1] ** 0.5)\n",
261+
" attention_scores = torch.softmax(attention_scores, dim=-1)\n",
262+
" attention_scores = self.dropout(attention_scores)\n",
245263
"\n",
246-
" attention_scores = Q @ K.transpose(1, 2) # (B, C, C)\n",
247-
" attention_scores = attention_scores.masked_fill(\n",
248-
" self.casual_attention_mask[:tokens_num,:tokens_num] == 0,\n",
249-
" -torch.inf\n",
250-
" )\n",
251-
" attention_scores = attention_scores / ( K.shape[-1] ** 0.5 )\n",
252-
" attention_scores = torch.softmax(attention_scores, dim=-1)\n",
253-
" attention_scores = self.dropout(attention_scores)\n",
254-
"\n",
255-
" return attention_scores @ V # (B, C, head_size)"
264+
" return attention_scores @ V # (B, C, head_size)"
256265
]
257266
},
258267
{
@@ -304,27 +313,27 @@
304313
},
305314
{
306315
"cell_type": "code",
307-
"execution_count": 16,
316+
"execution_count": null,
308317
"metadata": {},
309318
"outputs": [],
310319
"source": [
311320
"class MultiHeadAttention(nn.Module):\n",
312-
" def __init__(self, config):\n",
313-
" super().__init__()\n",
321+
" def __init__(self, config):\n",
322+
" super().__init__()\n",
314323
"\n",
315-
" heads_list = [AttentionHead(config) for _ in range(config[\"heads_num\"])]\n",
316-
" self.heads = nn.ModuleList(heads_list)\n",
324+
" heads_list = [AttentionHead(config) for _ in range(config[\"heads_num\"])]\n",
325+
" self.heads = nn.ModuleList(heads_list)\n",
317326
"\n",
318-
" self.linear = nn.Linear(config[\"embedding_dim\"], config[\"embedding_dim\"])\n",
319-
" self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
327+
" self.linear = nn.Linear(config[\"embedding_dim\"], config[\"embedding_dim\"])\n",
328+
" self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
320329
"\n",
321-
" def forward(self, input):\n",
322-
" heads_outputs = [head(input) for head in self.heads]\n",
330+
" def forward(self, input):\n",
331+
" heads_outputs = [head(input) for head in self.heads]\n",
323332
"\n",
324-
" scores_change = torch.cat(heads_outputs, dim=-1)\n",
333+
" scores_change = torch.cat(heads_outputs, dim=-1)\n",
325334
"\n",
326-
" scores_change = self.linear(scores_change)\n",
327-
" return self.dropout(scores_change)"
335+
" scores_change = self.linear(scores_change)\n",
336+
" return self.dropout(scores_change)"
328337
]
329338
},
330339
{

0 commit comments

Comments
 (0)