|
108 | 108 | "source": [
|
109 | 109 | "FLAGS.data_dir = \"tfds://imdb_reviews/plain_text\"\n",
|
110 | 110 | "FLAGS.attention_type = \"block_sparse\"\n",
|
111 |
| - "FLAGS.max_encoder_length = 3072 # 4096 on 16GB GPUs like V100, on free colab only lower memory GPU like T4 is available\n", |
| 111 | + "FLAGS.max_encoder_length = 4096 # reduce (e.g. to 3072) for a quicker demo or a lower-memory GPU on free Colab\n", |
112 | 112 | "FLAGS.learning_rate = 1e-5\n",
|
113 |
| - "FLAGS.num_train_steps = 10000\n", |
| 113 | + "FLAGS.num_train_steps = 2000\n", |
114 | 114 | "FLAGS.attention_probs_dropout_prob = 0.0\n",
|
115 | 115 | "FLAGS.hidden_dropout_prob = 0.0\n",
|
| 116 | + "FLAGS.use_gradient_checkpointing = True\n", |
116 | 117 | "FLAGS.vocab_model_file = \"gpt2\""
|
117 | 118 | ]
|
118 | 119 | },
|
|
146 | 147 | "source": [
|
147 | 148 | "model = modeling.BertModel(bert_config)\n",
|
148 | 149 | "headl = run_classifier.ClassifierLossLayer(\n",
|
149 |
| - " bert_config[\"num_labels\"], bert_config[\"hidden_dropout_prob\"],\n", |
| 150 | + " bert_config[\"hidden_size\"], bert_config[\"num_labels\"],\n", |
| 151 | + " bert_config[\"hidden_dropout_prob\"],\n", |
150 | 152 | " utils.create_initializer(bert_config[\"initializer_range\"]),\n",
|
151 | 153 | " name=bert_config[\"scope\"]+\"/classifier\")"
|
152 | 154 | ]
|
|
211 | 213 | " max_encoder_length=FLAGS.max_encoder_length,\n",
|
212 | 214 | " substitute_newline=FLAGS.substitute_newline,\n",
|
213 | 215 | " is_training=True)\n",
|
214 |
| - "dataset = train_input_fn({'batch_size': 2})" |
| 216 | + "dataset = train_input_fn({'batch_size': 8})" |
215 | 217 | ]
|
216 | 218 | },
|
217 | 219 | {
|
|
237 | 239 | "name": "stdout",
|
238 | 240 | "output_type": "stream",
|
239 | 241 | "text": [
|
240 |
| - "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n", |
| 242 | + "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n", |
241 | 243 | "array([[ 65, 733, 474, ..., 0, 0, 0],\n",
|
242 |
| - " [ 65, 415, 26500, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)\u003e)\n", |
243 |
| - "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n", |
244 |
| - "array([[ 65, 484, 20677, ..., 0, 0, 0],\n", |
245 |
| - " [ 65, 871, 3908, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)\u003e)\n", |
246 |
| - "(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n", |
247 |
| - "array([[ 65, 415, 6506, ..., 0, 0, 0],\n", |
248 |
| - " [ 65, 418, 1150, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 0], dtype=int32)\u003e)\n" |
| 244 | + " [ 65, 415, 26500, ..., 0, 0, 0],\n", |
| 245 | + " [ 65, 484, 20677, ..., 0, 0, 0],\n", |
| 246 | + " ...,\n", |
| 247 | + " [ 65, 418, 1150, ..., 0, 0, 0],\n", |
| 248 | + " [ 65, 9271, 5714, ..., 0, 0, 0],\n", |
| 249 | + " [ 65, 8301, 113, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 1, 0], dtype=int32)\u003e)\n", |
| 250 | + "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n", |
| 251 | + "array([[ 65, 1182, 358, ..., 0, 0, 0],\n", |
| 252 | + " [ 65, 871, 419, ..., 0, 0, 0],\n", |
| 253 | + " [ 65, 415, 1908, ..., 0, 0, 0],\n", |
| 254 | + " ...,\n", |
| 255 | + " [ 65, 484, 1722, ..., 0, 0, 0],\n", |
| 256 | + " [ 65, 876, 1154, ..., 0, 0, 0],\n", |
| 257 | + " [ 65, 415, 1092, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 0, 0, 1], dtype=int32)\u003e)\n", |
| 258 | + "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n", |
| 259 | + "array([[ 65, 456, 382, ..., 0, 0, 0],\n", |
| 260 | + " [ 65, 484, 34679, ..., 0, 0, 0],\n", |
| 261 | + " [ 65, 16224, 112, ..., 0, 0, 0],\n", |
| 262 | + " ...,\n", |
| 263 | + " [ 65, 484, 3822, ..., 0, 0, 0],\n", |
| 264 | + " [ 65, 484, 2747, ..., 0, 0, 0],\n", |
| 265 | + " [ 65, 415, 1208, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 0, 0, 0, 1, 0, 1, 0], dtype=int32)\u003e)\n" |
249 | 266 | ]
|
250 | 267 | }
|
251 | 268 | ],
|
|
261 | 278 | "id": "lYCyGH56zOOU"
|
262 | 279 | },
|
263 | 280 | "source": [
|
264 |
| - "## Check outputs" |
| 281 | + "## (Optional) Check outputs" |
265 | 282 | ]
|
266 | 283 | },
|
267 | 284 | {
|
|
385 | 402 | "name": "stderr",
|
386 | 403 | "output_type": "stream",
|
387 | 404 | "text": [
|
388 |
| - " 0%| | 0/10000 [00:06\u003c1:32:59, 1.79it/s]" |
| 405 | + " 0%| | 0/2000 [00:06\u003c1:59:12, 3.57s/it]" |
389 | 406 | ]
|
390 | 407 | },
|
391 | 408 | {
|
392 | 409 | "name": "stdout",
|
393 | 410 | "output_type": "stream",
|
394 | 411 | "text": [
|
395 | 412 | "\n",
|
396 |
| - "Loss = 0.4131925702095032 Accuracy = 0.8123108148574829" |
| 413 | + "Loss = 0.47779741883277893 Accuracy = 0.7558900713920593" |
397 | 414 | ]
|
398 | 415 | },
|
399 | 416 | {
|
400 | 417 | "name": "stderr",
|
401 | 418 | "output_type": "stream",
|
402 | 419 | "text": [
|
403 |
| - " 10%|█ | 1000/10000 [08:26\u003c1:16:08, 1.97it/s]" |
| 420 | + " 10%|█ | 200/2000 [11:26\u003c1:48:08, 3.60s/it]" |
404 | 421 | ]
|
405 | 422 | },
|
406 | 423 | {
|
407 | 424 | "name": "stdout",
|
408 | 425 | "output_type": "stream",
|
409 | 426 | "text": [
|
410 | 427 | "\n",
|
411 |
| - "Loss = 0.32566359639167786 Accuracy = 0.8608739376068115" |
| 428 | + "Loss = 0.3703668415546417 Accuracy = 0.8318414092063904" |
412 | 429 | ]
|
413 | 430 | },
|
414 | 431 | {
|
415 | 432 | "name": "stderr",
|
416 | 433 | "output_type": "stream",
|
417 | 434 | "text": [
|
418 |
| - " 20%|██ | 2000/10000 [16:52\u003c1:08:17, 1.95it/s]" |
| 435 | + " 20%|██ | 400/2000 [23:52\u003c1:35:17, 3.58s/it]" |
419 | 436 | ]
|
420 | 437 | },
|
421 | 438 | {
|
422 | 439 | "name": "stdout",
|
423 | 440 | "output_type": "stream",
|
424 | 441 | "text": [
|
425 | 442 | "\n",
|
426 |
| - "Loss = 0.28784531354904175 Accuracy = 0.882480800151825" |
| 443 | + "Loss = 0.3130376636981964 Accuracy = 0.8654822111129761" |
427 | 444 | ]
|
428 | 445 | },
|
429 | 446 | {
|
430 | 447 | "name": "stderr",
|
431 | 448 | "output_type": "stream",
|
432 | 449 | "text": [
|
433 |
| - " 30%|███ | 3000/10000 [25:18\u003c58:58, 1.98it/s]" |
| 450 | + " 30%|███ | 600/2000 [35:18\u003c1:24:58, 3.60s/it]" |
434 | 451 | ]
|
435 | 452 | },
|
436 | 453 | {
|
437 | 454 | "name": "stdout",
|
438 | 455 | "output_type": "stream",
|
439 | 456 | "text": [
|
440 | 457 | "\n",
|
441 |
| - "Loss = 0.2657429575920105 Accuracy = 0.8936356902122498" |
| 458 | + "Loss = 0.2806303799152374 Accuracy = 0.8822692632675171" |
442 | 459 | ]
|
443 | 460 | },
|
444 | 461 | {
|
445 | 462 | "name": "stderr",
|
446 | 463 | "output_type": "stream",
|
447 | 464 | "text": [
|
448 |
| - " 40%|████ | 4000/10000 [33:44\u003c50:41, 1.97it/s]" |
| 465 | + " 40%|████ | 800/2000 [47:44\u003c1:12:41, 3.60s/it]" |
449 | 466 | ]
|
450 | 467 | },
|
451 | 468 | {
|
452 | 469 | "name": "stdout",
|
453 | 470 | "output_type": "stream",
|
454 | 471 | "text": [
|
455 | 472 | "\n",
|
456 |
| - "Loss = 0.24971100687980652 Accuracy = 0.9020236134529114" |
| 473 | + "Loss = 0.2649693191051483 Accuracy = 0.8901362419128418" |
457 | 474 | ]
|
458 | 475 | },
|
459 | 476 | {
|
460 | 477 | "name": "stderr",
|
461 | 478 | "output_type": "stream",
|
462 | 479 | "text": [
|
463 |
| - " 50%|█████ | 5000/10000 [42:10\u003c42:03, 1.98it/s]" |
| 480 | + " 50%|█████ | 1000/2000 [59:10\u003c59:03, 3.58s/it]" |
464 | 481 | ]
|
465 | 482 | },
|
466 | 483 | {
|
467 | 484 | "name": "stdout",
|
468 | 485 | "output_type": "stream",
|
469 | 486 | "text": [
|
470 | 487 | "\n",
|
471 |
| - "Loss = 0.23958759009838104 Accuracy = 0.9069437384605408" |
| 488 | + "Loss = 0.25240564346313477 Accuracy = 0.8967254161834717" |
472 | 489 | ]
|
473 | 490 | },
|
474 | 491 | {
|
475 | 492 | "name": "stderr",
|
476 | 493 | "output_type": "stream",
|
477 | 494 | "text": [
|
478 |
| - " 60%|██████ | 6000/10000 [50:36\u003c33:43, 1.98it/s]" |
| 495 | + " 60%|██████ | 1200/2000 [1:11:36\u003c47:43, 3.60s/it]" |
479 | 496 | ]
|
480 | 497 | },
|
481 | 498 | {
|
482 | 499 | "name": "stdout",
|
483 | 500 | "output_type": "stream",
|
484 | 501 | "text": [
|
485 | 502 | "\n",
|
486 |
| - "Loss = 0.2304597944021225 Accuracy = 0.9108854532241821" |
| 503 | + "Loss = 0.24363534152507782 Accuracy = 0.901509702205658" |
487 | 504 | ]
|
488 | 505 | },
|
489 | 506 | {
|
490 | 507 | "name": "stderr",
|
491 | 508 | "output_type": "stream",
|
492 | 509 | "text": [
|
493 |
| - " 70%|███████ | 7000/10000 [59:02\u003c25:20, 1.97it/s]" |
| 510 | + " 70%|███████ | 1400/2000 [1:23:02\u003c35:20, 3.58s/it]" |
494 | 511 | ]
|
495 | 512 | },
|
496 | 513 | {
|
497 | 514 | "name": "stdout",
|
498 | 515 | "output_type": "stream",
|
499 | 516 | "text": [
|
500 | 517 | "\n",
|
501 |
| - "Loss = 0.2243848443031311 Accuracy = 0.9135903120040894" |
| 518 | + "Loss = 0.23414449393749237 Accuracy = 0.9062696695327759" |
502 | 519 | ]
|
503 | 520 | },
|
504 | 521 | {
|
505 | 522 | "name": "stderr",
|
506 | 523 | "output_type": "stream",
|
507 | 524 | "text": [
|
508 |
| - " 80%|████████ | 8000/10000 [1:07:30\u003c17:23, 1.92it/s]" |
| 525 | + " 80%|████████ | 1600/2000 [1:35:30\u003c23:23, 3.60s/it]" |
509 | 526 | ]
|
510 | 527 | },
|
511 | 528 | {
|
512 | 529 | "name": "stdout",
|
513 | 530 | "output_type": "stream",
|
514 | 531 | "text": [
|
515 | 532 | "\n",
|
516 |
| - "Loss = 0.21911397576332092 Accuracy = 0.9155822396278381" |
| 533 | + "Loss = 0.22541514039039612 Accuracy = 0.9101060628890991" |
517 | 534 | ]
|
518 | 535 | },
|
519 | 536 | {
|
520 | 537 | "name": "stderr",
|
521 | 538 | "output_type": "stream",
|
522 | 539 | "text": [
|
523 |
| - " 90%|█████████ | 9000/10000 [1:16:05\u003c08:34, 1.94it/s]" |
| 540 | + " 90%|█████████ | 1800/2000 [1:46:05\u003c11:34, 3.60s/it]" |
524 | 541 | ]
|
525 | 542 | },
|
526 | 543 | {
|
527 | 544 | "name": "stdout",
|
528 | 545 | "output_type": "stream",
|
529 | 546 | "text": [
|
530 | 547 | "\n",
|
531 |
| - "Loss = 0.21378542482852936 Accuracy = 0.9180262088775635" |
| 548 | + "Loss = 0.2210962176322937 Accuracy = 0.9125439524650574" |
532 | 549 | ]
|
533 | 550 | },
|
534 | 551 | {
|
535 | 552 | "name": "stderr",
|
536 | 553 | "output_type": "stream",
|
537 | 554 | "text": [
|
538 |
| - "100%|██████████| 10000/10000 [1:24:39\u003c00:00, 1.94it/s]" |
| 555 | + "100%|██████████| 2000/2000 [1:59:39\u003c00:00, 3.58s/it]" |
539 | 556 | ]
|
540 | 557 | },
|
541 | 558 | {
|
|
605 | 622 | " max_encoder_length=FLAGS.max_encoder_length,\n",
|
606 | 623 | " substitute_newline=FLAGS.substitute_newline,\n",
|
607 | 624 | " is_training=False)\n",
|
608 |
| - "eval_dataset = eval_input_fn({'batch_size': 2})" |
| 625 | + "eval_dataset = eval_input_fn({'batch_size': 8})" |
609 | 626 | ]
|
610 | 627 | },
|
611 | 628 | {
|
|
0 commit comments