18 | 18 | },
19 | 19 | {
20 | 20 | "cell_type": "code",
21 | | - "execution_count": 2,
| 21 | + "execution_count": null,
22 | 22 | "metadata": {},
23 | 23 | "outputs": [],
24 | 24 | "source": [
25 | 25 | "from pathlib import Path\n",
26 | 26 | "\n",
27 | | - "text = Path('../data/tiny-shakespeare.txt').read_text()"
| 27 | + "text = Path(\"../data/tiny-shakespeare.txt\").read_text()"
28 | 28 | ]
29 | 29 | },
30 | 30 | {
88 | 88 | },
89 | 89 | {
90 | 90 | "cell_type": "code",
91 | | - "execution_count": 4,
| 91 | + "execution_count": null,
92 | 92 | "metadata": {
93 | 93 | "id": "Ap_Ixr0M-0Yv"
94 | 94 | },
95 | 95 | "outputs": [],
96 | 96 | "source": [
97 | | - "\n",
98 | 97 | "class CharTokenizer:\n",
99 | | - "  def __init__(self, vocabulary):\n",
100 | | - "    self.token_id_for_char = {char: token_id for token_id, char in enumerate(vocabulary)}\n",
101 | | - "    self.char_for_token_id = {token_id: char for token_id, char in enumerate(vocabulary)}\n",
| 98 | + "    def __init__(self, vocabulary):\n",
| 99 | + "        self.token_id_for_char = {\n",
| 100 | + "            char: token_id for token_id, char in enumerate(vocabulary)\n",
| 101 | + "        }\n",
| 102 | + "        self.char_for_token_id = {\n",
| 103 | + "            token_id: char for token_id, char in enumerate(vocabulary)\n",
| 104 | + "        }\n",
102 | 105 | "\n",
103 | | - "  @staticmethod\n",
104 | | - "  def train_from_text(text):\n",
105 | | - "    vocabulary = set(text)\n",
106 | | - "    return CharTokenizer(sorted(list(vocabulary)))\n",
| 106 | + "    @staticmethod\n",
| 107 | + "    def train_from_text(text):\n",
| 108 | + "        vocabulary = set(text)\n",
| 109 | + "        return CharTokenizer(sorted(list(vocabulary)))\n",
107 | 110 | "\n",
108 | | - "  def encode(self, text):\n",
109 | | - "    token_ids = []\n",
110 | | - "    for char in text:\n",
111 | | - "      token_ids.append(self.token_id_for_char[char])\n",
112 | | - "    return torch.tensor(token_ids, dtype=torch.long)\n",
| 111 | + "    def encode(self, text):\n",
| 112 | + "        token_ids = []\n",
| 113 | + "        for char in text:\n",
| 114 | + "            token_ids.append(self.token_id_for_char[char])\n",
| 115 | + "        return torch.tensor(token_ids, dtype=torch.long)\n",
113 | 116 | "\n",
114 | | - "  def decode(self, token_ids):\n",
115 | | - "    chars = []\n",
116 | | - "    for token_id in token_ids.tolist():\n",
117 | | - "      chars.append(self.char_for_token_id[token_id])\n",
118 | | - "    return ''.join(chars)\n",
| 117 | + "    def decode(self, token_ids):\n",
| 118 | + "        chars = []\n",
| 119 | + "        for token_id in token_ids.tolist():\n",
| 120 | + "            chars.append(self.char_for_token_id[token_id])\n",
| 121 | + "        return \"\".join(chars)\n",
119 | 122 | "\n",
120 | | - "\n",
121 | | - "  def vocabulary_size(self):\n",
122 | | - "    return len(self.token_id_for_char)"
| 123 | + "    def vocabulary_size(self):\n",
| 124 | + "        return len(self.token_id_for_char)"
123 | 125 | ]
124 | 126 | },
125 | 127 | {
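A quick round-trip check of the CharTokenizer cell above; a minimal sketch that assumes `text` from the first cell and that `torch` is imported earlier in the notebook:

# Train the tokenizer on the corpus and verify decode inverts encode.
tokenizer = CharTokenizer.train_from_text(text)
print("vocabulary size:", tokenizer.vocabulary_size())

token_ids = tokenizer.encode("hello world")           # 1-D LongTensor of token ids
assert tokenizer.decode(token_ids) == "hello world"   # decode restores the string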
175 | 177 | },
176 | 178 | {
177 | 179 | "cell_type": "code",
178 | | - "execution_count": 8,
| 180 | + "execution_count": null,
179 | 181 | "metadata": {
180 | 182 | "id": "7Qal76ig-94U"
181 | 183 | },
182 | 184 | "outputs": [],
183 | 185 | "source": [
184 | 186 | "from torch.utils.data import Dataset\n",
185 | 187 | "\n",
| 188 | + "\n",
186 | 189 | "class TokenIdsDataset(Dataset):\n",
187 | | - "  def __init__(self, data, block_size):\n",
188 | | - "    self.data = data\n",
189 | | - "    self.block_size = block_size\n",
| 190 | + "    def __init__(self, data, block_size):\n",
| 191 | + "        self.data = data\n",
| 192 | + "        self.block_size = block_size\n",
190 | 193 | "\n",
191 | | - "  def __len__(self):\n",
192 | | - "    return len(self.data) - self.block_size\n",
| 194 | + "    def __len__(self):\n",
| 195 | + "        return len(self.data) - self.block_size\n",
193 | 196 | "\n",
194 | | - "  def __getitem__(self, pos):\n",
195 | | - "    assert pos < len(self.data) - self.block_size\n",
| 197 | + "    def __getitem__(self, pos):\n",
| 198 | + "        assert pos < len(self.data) - self.block_size\n",
196 | 199 | "\n",
197 | | - "    x = self.data[pos:pos + self.block_size]\n",
198 | | - "    y = self.data[pos + 1:pos + 1 + self.block_size]\n",
199 | | - "    return x, y"
| 200 | + "        x = self.data[pos : pos + self.block_size]\n",
| 201 | + "        y = self.data[pos + 1 : pos + 1 + self.block_size]\n",
| 202 | + "        return x, y"
200 | 203 | ]
201 | 204 | },
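Each item in the dataset above pairs a block of token ids with the same block shifted one position, and it is typically consumed through a DataLoader. A minimal sketch, assuming `tokenizer` and `text` from the earlier cells:

from torch.utils.data import DataLoader

token_ids = tokenizer.encode(text)                    # whole corpus as one tensor
dataset = TokenIdsDataset(token_ids, block_size=64)   # any block up to the model's context size (256 here)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

x, y = next(iter(loader))   # x: (8, 64); y is x shifted right by one token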
202 | 205 | {
203 | 206 | "cell_type": "code",
204 | | - "execution_count": 10,
| 207 | + "execution_count": null,
205 | 208 | "metadata": {},
206 | 209 | "outputs": [],
207 | 210 | "source": [
208 | 211 | "config = {\n",
209 | | - "  \"vocabulary_size\": tokenizer.vocabulary_size(),\n",
210 | | - "  \"context_size\": 256,\n",
211 | | - "  \"embedding_dim\": 768,\n",
212 | | - "  \"heads_num\": 12,\n",
213 | | - "  \"layers_num\": 10,\n",
214 | | - "  \"dropout_rate\": 0.1,\n",
215 | | - "  \"use_bias\": False,\n",
| 212 | + "    \"vocabulary_size\": tokenizer.vocabulary_size(),\n",
| 213 | + "    \"context_size\": 256,\n",
| 214 | + "    \"embedding_dim\": 768,\n",
| 215 | + "    \"heads_num\": 12,\n",
| 216 | + "    \"layers_num\": 10,\n",
| 217 | + "    \"dropout_rate\": 0.1,\n",
| 218 | + "    \"use_bias\": False,\n",
216 | 219 | "}\n",
217 | 220 | "\n",
218 | 221 | "config[\"head_size\"] = config[\"embedding_dim\"] // config[\"heads_num\"]"
219 | 222 | ]
220 | 223 | },
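The derived value in the config cell is the per-head width: with embedding_dim 768 and 12 heads, head_size is 768 // 12 = 64. A small sanity check, as a sketch:

# The embedding width must split evenly across heads, since multi-head attention
# concatenates heads_num outputs of head_size back to embedding_dim.
assert config["embedding_dim"] % config["heads_num"] == 0
print(config["head_size"])   # 64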
221 | 224 | {
222 | 225 | "cell_type": "code",
223 | | - "execution_count": 11,
| 226 | + "execution_count": null,
224 | 227 | "metadata": {},
225 | 228 | "outputs": [],
226 | 229 | "source": [
227 | 230 | "class AttentionHead(nn.Module):\n",
228 | | - "  def __init__(self, config):\n",
229 | | - "    super().__init__()\n",
230 | | - "    self.Q_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
231 | | - "    self.K_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
232 | | - "    self.V_weights = nn.Linear(config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"])\n",
| 231 | + "    def __init__(self, config):\n",
| 232 | + "        super().__init__()\n",
| 233 | + "        self.Q_weights = nn.Linear(\n",
| 234 | + "            config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
| 235 | + "        )\n",
| 236 | + "        self.K_weights = nn.Linear(\n",
| 237 | + "            config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
| 238 | + "        )\n",
| 239 | + "        self.V_weights = nn.Linear(\n",
| 240 | + "            config[\"embedding_dim\"], config[\"head_size\"], config[\"use_bias\"]\n",
| 241 | + "        )\n",
233 | 242 | "\n",
234 | | - "    self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
| 243 | + "        self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
235 | 244 | "\n",
236 | | - "    casual_attention_mask = torch.tril(torch.ones(config[\"context_size\"], config[\"context_size\"]))\n",
237 | | - "    self.register_buffer('casual_attention_mask', casual_attention_mask)\n",
| 245 | + "        casual_attention_mask = torch.tril(\n",
| 246 | + "            torch.ones(config[\"context_size\"], config[\"context_size\"])\n",
| 247 | + "        )\n",
| 248 | + "        self.register_buffer(\"casual_attention_mask\", casual_attention_mask)\n",
238 | 249 | "\n",
| 250 | + "    def forward(self, input):  # (B, C, embedding_dim)\n",
| 251 | + "        batch_size, tokens_num, embedding_dim = input.shape\n",
| 252 | + "        Q = self.Q_weights(input)  # (B, C, head_size)\n",
| 253 | + "        K = self.K_weights(input)  # (B, C, head_size)\n",
| 254 | + "        V = self.V_weights(input)  # (B, C, head_size)\n",
239 | 255 | "\n",
240 | | - "  def forward(self, input): # (B, C, embedding_dim)\n",
241 | | - "    batch_size, tokens_num, embedding_dim = input.shape\n",
242 | | - "    Q = self.Q_weights(input) # (B, C, head_size)\n",
243 | | - "    K = self.K_weights(input) # (B, C, head_size)\n",
244 | | - "    V = self.V_weights(input) # (B, C, head_size)\n",
| 256 | + "        attention_scores = Q @ K.transpose(1, 2)  # (B, C, C)\n",
| 257 | + "        attention_scores = attention_scores.masked_fill(\n",
| 258 | + "            self.casual_attention_mask[:tokens_num, :tokens_num] == 0, -torch.inf\n",
| 259 | + "        )\n",
| 260 | + "        attention_scores = attention_scores / (K.shape[-1] ** 0.5)\n",
| 261 | + "        attention_scores = torch.softmax(attention_scores, dim=-1)\n",
| 262 | + "        attention_scores = self.dropout(attention_scores)\n",
245 | 263 | "\n",
246 | | - "    attention_scores = Q @ K.transpose(1, 2) # (B, C, C)\n",
247 | | - "    attention_scores = attention_scores.masked_fill(\n",
248 | | - "      self.casual_attention_mask[:tokens_num,:tokens_num] == 0,\n",
249 | | - "      -torch.inf\n",
250 | | - "    )\n",
251 | | - "    attention_scores = attention_scores / ( K.shape[-1] ** 0.5 )\n",
252 | | - "    attention_scores = torch.softmax(attention_scores, dim=-1)\n",
253 | | - "    attention_scores = self.dropout(attention_scores)\n",
254 | | - "\n",
255 | | - "    return attention_scores @ V # (B, C, head_size)"
| 264 | + "        return attention_scores @ V  # (B, C, head_size)"
256 | 265 | ]
257 | 266 | },
258 | 267 | {
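A shape check for the attention head above; a minimal sketch assuming `torch`, `nn`, and `config` are in scope from earlier cells:

head = AttentionHead(config)
x = torch.randn(4, 16, config["embedding_dim"])   # batch of 4 sequences, 16 tokens each

out = head(x)      # causal self-attention for a single head
print(out.shape)   # torch.Size([4, 16, 64]), i.e. (B, C, head_size)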
304 | 313 | },
305 | 314 | {
306 | 315 | "cell_type": "code",
307 | | - "execution_count": 16,
| 316 | + "execution_count": null,
308 | 317 | "metadata": {},
309 | 318 | "outputs": [],
310 | 319 | "source": [
311 | 320 | "class MultiHeadAttention(nn.Module):\n",
312 | | - "  def __init__(self, config):\n",
313 | | - "    super().__init__()\n",
| 321 | + "    def __init__(self, config):\n",
| 322 | + "        super().__init__()\n",
314 | 323 | "\n",
315 | | - "    heads_list = [AttentionHead(config) for _ in range(config[\"heads_num\"])]\n",
316 | | - "    self.heads = nn.ModuleList(heads_list)\n",
| 324 | + "        heads_list = [AttentionHead(config) for _ in range(config[\"heads_num\"])]\n",
| 325 | + "        self.heads = nn.ModuleList(heads_list)\n",
317 | 326 | "\n",
318 | | - "    self.linear = nn.Linear(config[\"embedding_dim\"], config[\"embedding_dim\"])\n",
319 | | - "    self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
| 327 | + "        self.linear = nn.Linear(config[\"embedding_dim\"], config[\"embedding_dim\"])\n",
| 328 | + "        self.dropout = nn.Dropout(config[\"dropout_rate\"])\n",
320 | 329 | "\n",
321 | | - "  def forward(self, input):\n",
322 | | - "    heads_outputs = [head(input) for head in self.heads]\n",
| 330 | + "    def forward(self, input):\n",
| 331 | + "        heads_outputs = [head(input) for head in self.heads]\n",
323 | 332 | "\n",
324 | | - "    scores_change = torch.cat(heads_outputs, dim=-1)\n",
| 333 | + "        scores_change = torch.cat(heads_outputs, dim=-1)\n",
325 | 334 | "\n",
326 | | - "    scores_change = self.linear(scores_change)\n",
327 | | - "    return self.dropout(scores_change)"
| 335 | + "        scores_change = self.linear(scores_change)\n",
| 336 | + "        return self.dropout(scores_change)"
328 | 337 | ]
329 | 338 | },
330 | 339 | {
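The multi-head wrapper keeps the embedding width, since it concatenates heads_num slices of head_size and projects back. A minimal sketch, again assuming `torch`, `nn`, and `config` from earlier cells:

mha = MultiHeadAttention(config)
x = torch.randn(4, 16, config["embedding_dim"])

out = mha(x)       # concat of 12 head outputs, then the output projection
print(out.shape)   # torch.Size([4, 16, 768]), same shape as the input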