Commit 5ede318

BigBird Team (manzilz) authored and committed

Project import generated by Copybara.

PiperOrigin-RevId: 369239135

1 parent ad71331 · commit 5ede318

21 files changed: +1592 −967 lines

README.md (−1)

@@ -161,4 +161,3 @@ is no benefit in using sparse BigBird attention.
 Recently, [Long Range Arena](https://arxiv.org/pdf/2011.04006.pdf) provided a benchmark of six tasks that require longer context, and performed experiments to benchmark all existing long range transformers. The results are shown below. BigBird model, unlike its counterparts, clearly reduces memory consumption without sacrificing performance.
 
 <img src="https://github.com/google-research/bigbird/blob/master/comparison.png" width="50%">
-

bigbird/classifier/__init__.py (+1 −1)

@@ -1,4 +1,4 @@
-# Copyright 2020 The BigBird Authors.
+# Copyright 2021 The BigBird Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

bigbird/classifier/imdb.ipynb (+52 −35)

@@ -108,11 +108,12 @@
 "source": [
 "FLAGS.data_dir = \"tfds://imdb_reviews/plain_text\"\n",
 "FLAGS.attention_type = \"block_sparse\"\n",
-"FLAGS.max_encoder_length = 3072 # 4096 on 16GB GPUs like V100, on free colab only lower memory GPU like T4 is available\n",
+"FLAGS.max_encoder_length = 4096 # reduce for quicker demo on free colab\n",
 "FLAGS.learning_rate = 1e-5\n",
-"FLAGS.num_train_steps = 10000\n",
+"FLAGS.num_train_steps = 2000\n",
 "FLAGS.attention_probs_dropout_prob = 0.0\n",
 "FLAGS.hidden_dropout_prob = 0.0\n",
+"FLAGS.use_gradient_checkpointing = True\n",
 "FLAGS.vocab_model_file = \"gpt2\""
 ]
 },
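
Note on the new use_gradient_checkpointing flag: recomputing activations during the backward pass instead of storing them is what lets the larger settings above (max_encoder_length 4096, and batch_size 8 below) fit on the same accelerator, at the cost of extra compute per step. A minimal sketch of the general technique in TF2, using the public tf.recompute_grad wrapper; the two-layer dense_block is an illustrative stand-in, not BigBird's encoder:

import tensorflow as tf

# Toy block standing in for one transformer layer (illustrative only).
dense_block = tf.keras.Sequential([
    tf.keras.layers.Dense(512, activation="relu"),
    tf.keras.layers.Dense(512, activation="relu"),
])
dense_block.build((None, 512))  # create weights before wrapping

# tf.recompute_grad frees the block's intermediate activations after the
# forward pass and recomputes them on the fly during backprop, trading
# extra compute for a lower peak-memory footprint.
checkpointed_block = tf.recompute_grad(dense_block)

x = tf.random.normal([8, 512])
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(checkpointed_block(x))
grads = tape.gradient(loss, dense_block.trainable_variables)
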
@@ -146,7 +147,8 @@
 "source": [
 "model = modeling.BertModel(bert_config)\n",
 "headl = run_classifier.ClassifierLossLayer(\n",
-"    bert_config[\"num_labels\"], bert_config[\"hidden_dropout_prob\"],\n",
+"    bert_config[\"hidden_size\"], bert_config[\"num_labels\"],\n",
+"    bert_config[\"hidden_dropout_prob\"],\n",
 "    utils.create_initializer(bert_config[\"initializer_range\"]),\n",
 "    name=bert_config[\"scope\"]+\"/classifier\")"
 ]

@@ -211,7 +213,7 @@
 "    max_encoder_length=FLAGS.max_encoder_length,\n",
 "    substitute_newline=FLAGS.substitute_newline,\n",
 "    is_training=True)\n",
-"dataset = train_input_fn({'batch_size': 2})"
+"dataset = train_input_fn({'batch_size': 8})"
 ]
 },
 {

@@ -237,15 +239,30 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
+"(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
 "array([[ 65, 733, 474, ..., 0, 0, 0],\n",
-" [ 65, 415, 26500, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)\u003e)\n",
-"(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
-"array([[ 65, 484, 20677, ..., 0, 0, 0],\n",
-" [ 65, 871, 3908, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)\u003e)\n",
-"(\u003ctf.Tensor: shape=(2, 4096), dtype=int32, numpy=\n",
-"array([[ 65, 415, 6506, ..., 0, 0, 0],\n",
-" [ 65, 418, 1150, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 0], dtype=int32)\u003e)\n"
+" [ 65, 415, 26500, ..., 0, 0, 0],\n",
+" [ 65, 484, 20677, ..., 0, 0, 0],\n",
+" ...,\n",
+" [ 65, 418, 1150, ..., 0, 0, 0],\n",
+" [ 65, 9271, 5714, ..., 0, 0, 0],\n",
+" [ 65, 8301, 113, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 1, 0], dtype=int32)\u003e)\n",
+"(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
+"array([[ 65, 1182, 358, ..., 0, 0, 0],\n",
+" [ 65, 871, 419, ..., 0, 0, 0],\n",
+" [ 65, 415, 1908, ..., 0, 0, 0],\n",
+" ...,\n",
+" [ 65, 484, 1722, ..., 0, 0, 0],\n",
+" [ 65, 876, 1154, ..., 0, 0, 0],\n",
+" [ 65, 415, 1092, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 0, 0, 1], dtype=int32)\u003e)\n",
+"(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
+"array([[ 65, 456, 382, ..., 0, 0, 0],\n",
+" [ 65, 484, 34679, ..., 0, 0, 0],\n",
+" [ 65, 16224, 112, ..., 0, 0, 0],\n",
+" ...,\n",
+" [ 65, 484, 3822, ..., 0, 0, 0],\n",
+" [ 65, 484, 2747, ..., 0, 0, 0],\n",
+" [ 65, 415, 1208, ..., 0, 0, 0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 0, 0, 0, 1, 0, 1, 0], dtype=int32)\u003e)\n"
 ]
 }
 ],

@@ -261,7 +278,7 @@
 "id": "lYCyGH56zOOU"
 },
 "source": [
-"## Check outputs"
+"## (Optionally) Check outputs"
 ]
 },
 {

@@ -385,157 +402,157 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 0%| | 0/10000 [00:06\u003c1:32:59, 1.79it/s]"
+" 0%| | 0/2000 [00:06\u003c1:59:12, 3.57it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.4131925702095032 Accuracy = 0.8123108148574829"
+"Loss = 0.47779741883277893 Accuracy = 0.7558900713920593"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 10%|█ | 1000/10000 [08:26\u003c1:16:08, 1.97it/s]"
+" 10%|█ | 200/2000 [11:26\u003c1:48:08, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.32566359639167786 Accuracy = 0.8608739376068115"
+"Loss = 0.3703668415546417 Accuracy = 0.8318414092063904"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 20%|██ | 2000/10000 [16:52\u003c1:08:17, 1.95it/s]"
+" 20%|██ | 400/2000 [23:52\u003c1:35:17, 3.58it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.28784531354904175 Accuracy = 0.882480800151825"
+"Loss = 0.3130376636981964 Accuracy = 0.8654822111129761"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 30%|███ | 3000/10000 [25:18\u003c58:58, 1.98it/s]"
+" 30%|███ | 600/2000 [35:18\u003c1:24:58, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.2657429575920105 Accuracy = 0.8936356902122498"
+"Loss = 0.2806303799152374 Accuracy = 0.8822692632675171"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 40%|████ | 4000/10000 [33:44\u003c50:41, 1.97it/s]"
+" 40%|████ | 800/2000 [47:44\u003c1:12:41, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.24971100687980652 Accuracy = 0.9020236134529114"
+"Loss = 0.2649693191051483 Accuracy = 0.8901362419128418"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 50%|█████ | 5000/10000 [42:10\u003c42:03, 1.98it/s]"
+" 50%|█████ | 1000/2000 [59:10\u003c59:03, 3.58it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.23958759009838104 Accuracy = 0.9069437384605408"
+"Loss = 0.25240564346313477 Accuracy = 0.8967254161834717"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 60%|██████ | 6000/10000 [50:36\u003c33:43, 1.98it/s]"
+" 60%|██████ | 1200/2000 [1:11:36\u003c47:43, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.2304597944021225 Accuracy = 0.9108854532241821"
+"Loss = 0.24363534152507782 Accuracy = 0.901509702205658"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 70%|███████ | 7000/10000 [59:02\u003c25:20, 1.97it/s]"
+" 70%|███████ | 1400/2000 [1:23:02\u003c35:20, 3.58it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.2243848443031311 Accuracy = 0.9135903120040894"
+"Loss = 0.23414449393749237 Accuracy = 0.9062696695327759"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 80%|████████ | 8000/10000 [1:07:30\u003c17:23, 1.92it/s]"
+" 80%|████████ | 1600/2000 [1:35:30\u003c23:23, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.21911397576332092 Accuracy = 0.9155822396278381"
+"Loss = 0.22541514039039612 Accuracy = 0.9101060628890991"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-" 90%|█████████ | 9000/10000 [1:16:05\u003c08:34, 1.94it/s]"
+" 90%|█████████ | 1800/2000 [1:46:05\u003c11:34, 3.60it/s]"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Loss = 0.21378542482852936 Accuracy = 0.9180262088775635"
+"Loss = 0.2210962176322937 Accuracy = 0.9125439524650574"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"100%|██████████| 10000/10000 [1:24:39\u003c00:00, 1.94it/s]"
+"100%|██████████| 2000/2000 [1:59:39\u003c00:00, 3.58it/s]"
 ]
 },
 {

@@ -605,7 +622,7 @@
 "    max_encoder_length=FLAGS.max_encoder_length,\n",
 "    substitute_newline=FLAGS.substitute_newline,\n",
 "    is_training=False)\n",
-"eval_dataset = eval_input_fn({'batch_size': 2})"
+"eval_dataset = eval_input_fn({'batch_size': 8})"
 ]
 },
 {

bigbird/classifier/run_classifier.py (+21 −22)
@@ -1,4 +1,4 @@
-# Copyright 2020 The BigBird Authors.
+# Copyright 2021 The BigBird Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -238,7 +238,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
 
   model = modeling.BertModel(bert_config)
   headl = ClassifierLossLayer(
-      bert_config["num_labels"], bert_config["hidden_dropout_prob"],
+      bert_config["hidden_size"], bert_config["num_labels"],
+      bert_config["hidden_dropout_prob"],
       utils.create_initializer(bert_config["initializer_range"]),
       name=bert_config["scope"]+"/classifier")
 
@@ -310,44 +311,41 @@ def metric_fn(loss_value, label_ids, log_probs):
   return model_fn
 
 
-class ClassifierLossLayer(tf.compat.v1.layers.Layer):
+class ClassifierLossLayer(tf.keras.layers.Layer):
   """Final classifier layer with loss."""
 
   def __init__(self,
+               hidden_size,
                num_labels,
                dropout_prob=0.0,
                initializer=None,
                use_bias=True,
                name="classifier"):
     super(ClassifierLossLayer, self).__init__(name=name)
+    self.hidden_size = hidden_size
     self.num_labels = num_labels
     self.initializer = initializer
-    self.dropout_prob = dropout_prob
+    self.dropout = tf.keras.layers.Dropout(dropout_prob)
     self.use_bias = use_bias
 
-    self.w = None
-    self.b = None
-
-  def call(self, input_tensor, labels=None, training=None):
-    last_dim = utils.get_shape_list(input_tensor)[-1]
-    input_tensor = utils.dropout(input_tensor, self.dropout_prob, training)
-
-    if self.w is None:
+    with tf.compat.v1.variable_scope(name):
       self.w = tf.compat.v1.get_variable(
           name="kernel",
-          shape=[last_dim, self.num_labels],
+          shape=[self.hidden_size, self.num_labels],
           initializer=self.initializer)
-      self.initializer = None
-      self._trainable_weights.append(self.w)
-    logits = tf.matmul(input_tensor, self.w)
-
-    if self.use_bias:
-      if self.b is None:
+      if self.use_bias:
         self.b = tf.compat.v1.get_variable(
             name="bias",
             shape=[self.num_labels],
             initializer=tf.zeros_initializer)
-        self._trainable_weights.append(self.b)
+      else:
+        self.b = None
+
+  def call(self, input_tensor, labels=None, training=None):
+    input_tensor = self.dropout(input_tensor, training)
+
+    logits = tf.matmul(input_tensor, self.w)
+    if self.use_bias:
       logits = tf.nn.bias_add(logits, self.b)
 
     log_probs = tf.nn.log_softmax(logits, axis=-1)
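
The refactor above moves weight creation out of call() and into __init__(): because the constructor now receives hidden_size, the kernel shape is known at construction time, call() becomes a pure forward pass, and the manual self._trainable_weights bookkeeping disappears since tf.keras.layers.Layer tracks tf.Variable attributes automatically. A condensed, self-contained sketch of the same pattern, with toy dimensions and the loss computation omitted — a simplified stand-in, not the exact BigBird class:

import tensorflow as tf

class ClassifierHead(tf.keras.layers.Layer):
  """Eager-weight variant: variable shapes are known at construction."""

  def __init__(self, hidden_size, num_labels, dropout_prob=0.0,
               name="classifier"):
    super().__init__(name=name)
    self.dropout = tf.keras.layers.Dropout(dropout_prob)
    # variable_scope keeps checkpoint-compatible variable names.
    with tf.compat.v1.variable_scope(name):
      self.w = tf.compat.v1.get_variable(
          name="kernel", shape=[hidden_size, num_labels],
          initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
      self.b = tf.compat.v1.get_variable(
          name="bias", shape=[num_labels],
          initializer=tf.zeros_initializer)

  def call(self, input_tensor, training=None):
    x = self.dropout(input_tensor, training)  # active only when training=True
    return tf.matmul(x, self.w) + self.b

# Hypothetical usage with toy sizes:
head = ClassifierHead(hidden_size=768, num_labels=2)
logits = head(tf.random.normal([8, 768]), training=False)
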
@@ -382,6 +380,7 @@ def main(_):
 
   model_fn = model_fn_builder(bert_config)
   estimator = utils.get_estimator(bert_config, model_fn)
+  tmp_data_dir = os.path.join(FLAGS.output_dir, "tfds")
 
   if FLAGS.do_train:
     logging.info("***** Running training *****")

@@ -392,7 +391,7 @@ def main(_):
         vocab_model_file=FLAGS.vocab_model_file,
         max_encoder_length=FLAGS.max_encoder_length,
         substitute_newline=FLAGS.substitute_newline,
-        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+        tmp_dir=tmp_data_dir,
         is_training=True)
     estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

@@ -405,7 +404,7 @@ def main(_):
         vocab_model_file=FLAGS.vocab_model_file,
         max_encoder_length=FLAGS.max_encoder_length,
         substitute_newline=FLAGS.substitute_newline,
-        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
+        tmp_dir=tmp_data_dir,
         is_training=False)
 
     if FLAGS.use_tpu:

bigbird/core/__init__.py (+1 −1)

@@ -1,4 +1,4 @@
-# Copyright 2020 The BigBird Authors.
+# Copyright 2021 The BigBird Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
