
Commit c110fab

artitw authored and lukaszkaiser committed
New problem: mathematical language understanding (tensorflow#1290)
* fix bAbi data generator and readme
* fix bAbi hparams deletion
* fix bAbi hparams: delete unnecessary keys
* fix bAbi hparams: clean keys
* bAbi hparams: delete keys
* fix readme
* fix universal transformer decoding
* fix merge conflict
* mathematical language understanding
* clarify usage
* add to authors
1 parent 111466d commit c110fab

File tree: 6 files changed, +129 -5 lines changed


AUTHORS (+1)

@@ -5,3 +5,4 @@
 # of contributors, see the revision history in source control.
 
 Google Inc.
+Artit Wangperawong

README.md (+20)

@@ -47,6 +47,7 @@ pip install tensor2tensor && t2t-trainer \
 ### Contents
 
 * [Suggested Datasets and Models](#suggested-datasets-and-models)
+  * [Mathematical Language Understanding](#mathematical-language-understanding)
   * [Story, Question and Answer](#story-question-and-answer)
   * [Image Classification](#image-classification)
   * [Image Generation](#image-generation)

@@ -79,6 +80,24 @@ hyperparameters that we know works well in our setup. We usually
 run either on Cloud TPUs or on 8-GPU machines; you might need
 to modify the hyperparameters if you run on a different setup.
 
+### Mathematical Language Understanding
+
+For evaluating mathematical expressions at the character level, involving addition, subtraction and multiplication of positive and negative decimal numbers with variable numbers of digits assigned to symbolic variables, use
+
+* the [MLU](https://art.wangperawong.com/mathematical_language_understanding_train.tar.gz) dataset:
+  `--problem=mathematical_language_understanding`
+
+You can try solving the problem with different transformer models and hyperparameters, as described in the [paper](https://arxiv.org/abs/1812.02825):
+* Standard transformer:
+  `--model=transformer`
+  `--hparams_set=transformer_tiny`
+* Universal transformer:
+  `--model=universal_transformer`
+  `--hparams_set=universal_transformer_tiny`
+* Adaptive universal transformer:
+  `--model=universal_transformer`
+  `--hparams_set=adaptive_universal_transformer_tiny`
+
 ### Story, Question and Answer
 
 For answering questions based on a story, use

@@ -464,5 +483,6 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
 * [Fast Decoding in Sequence Models using Discrete Latent Variables](https://arxiv.org/abs/1803.03382)
 * [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
 * [Universal Transformers](https://arxiv.org/abs/1807.03819)
+* [Attending to Mathematical Language with Transformers](https://arxiv.org/abs/1812.02825)
 
 *Note: This is not an official Google product.*
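
As a cross-check of the flags documented above, here is a minimal sketch of driving the same problem from Python rather than the `t2t-trainer` command line. It assumes a standard tensor2tensor install; the `/tmp` paths are illustrative placeholders, not prescribed locations.

```python
# Minimal sketch: generate MLU data via the problem registry instead of
# t2t-datagen. The /tmp paths below are illustrative placeholders.
from tensor2tensor import problems  # importing this registers all problems

mlu = problems.problem("mathematical_language_understanding")
mlu.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")  # download, extract, write shards
```

Training itself would still go through `t2t-trainer` with one of the `--model`/`--hparams_set` combinations listed above.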

tensor2tensor/data_generators/all_problems.py (+1)

@@ -50,6 +50,7 @@
     "tensor2tensor.data_generators.lm1b",
     "tensor2tensor.data_generators.lm1b_imdb",
     "tensor2tensor.data_generators.lm1b_mnli",
+    "tensor2tensor.data_generators.mathematical_language_understanding",
     "tensor2tensor.data_generators.mnist",
     "tensor2tensor.data_generators.mrpc",
     "tensor2tensor.data_generators.mscoco",

tensor2tensor/data_generators/babi_qa.py (+2 -3)

@@ -109,9 +109,9 @@ def _prepare_babi_data(tmp_dir, data_dir):
     tf.gfile.MakeDirs(data_dir)
 
   file_path = os.path.join(tmp_dir, _TAR)
-  headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36"}  # pylint: disable=line-too-long
+  headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
   resp = requests.get(_URL, headers=headers)
-  with open(file_path, "wb") as f:
+  with open(file_path, 'wb') as f:
     f.write(resp.content)
 
   tar = tarfile.open(file_path)

@@ -459,7 +459,6 @@ def hparams(self, defaults, unused_model_hparams):
     if "context" in p.vocab_size:
       del p.vocab_size["context"]
 
-
 def _problems_to_register():
   """Problems for which we want to create datasets.
tensor2tensor/data_generators/mathematical_language_understanding.py (new file, +104)

# coding=utf-8
# Copyright 2018 Artit Wangperawong [email protected]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Data generators for the Mathematical Language Understanding dataset.

The training and test data were generated by assigning symbolic variables
either positive or negative decimal integers and then describing the algebraic
operation to perform. We restrict our variable assignments to the range
x,y->[-1000,1000) and the operations to the set {+,-,*}. To ensure that the
model embraces symbolic variables, the order in which x and y appear in the
expression is randomly chosen. For instance, an input string might be
y=129,x=531,x-y. Each input string is accompanied by its target string, which
is the evaluation of the mathematical expression. For this study, all targets
considered are decimal integers represented at the character level. About 12
million unique samples were generated this way and randomly split into
training and test sets at an approximate ratio of 9:1.

For more information, see the following paper:
Artit Wangperawong. Attending to Mathematical Language with Transformers,
arXiv:1812.02825. Available at: https://arxiv.org/abs/1812.02825
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

import tensorflow as tf


@registry.register_problem
class MathematicalLanguageUnderstanding(text_problems.Text2TextProblem):
  """Problem spec for character-level arithmetic on symbolic variables."""

  URL = ("https://art.wangperawong.com/"
         "mathematical_language_understanding_train.tar.gz")

  @property
  def vocab_type(self):
    return text_problems.VocabType.CHARACTER

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 10,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }]

  @property
  def is_generate_per_split(self):
    return False

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Downloads and extracts the dataset, then yields examples.

    Args:
      data_dir: The base directory where data and vocab files are stored.
      tmp_dir: Temp directory to download and extract the dataset into.
      dataset_split: Unused; since is_generate_per_split is False, the
        train/eval split is made according to dataset_splits.

    Yields:
      Dicts with "inputs" and "targets" strings, one per dataset line.
    """
    if not tf.gfile.Exists(tmp_dir):
      tf.gfile.MakeDirs(tmp_dir)

    if not tf.gfile.Exists(data_dir):
      tf.gfile.MakeDirs(data_dir)

    # Download and extract the pre-generated dataset.
    compressed_filename = os.path.basename(self.URL)
    download_path = generator_utils.maybe_download(
        tmp_dir, compressed_filename, self.URL)

    with tarfile.open(download_path, "r:gz") as tar:
      tar.extractall(tmp_dir)

    filepath = os.path.join(
        tmp_dir, "mathematical_language_understanding_train.txt")

    # Each line holds one example as "<inputs>:<targets>".
    with open(filepath, "r") as fp:
      for l in fp:
        prob, ans = l.strip().split(":")
        yield {"inputs": prob, "targets": ans}

tensor2tensor/models/research/universal_transformer.py (+1 -2)

@@ -243,8 +243,7 @@ def _greedy_infer(self, features, decode_length, use_tpu=False):
     return (self._slow_greedy_infer_tpu(features, decode_length) if use_tpu else
             self._slow_greedy_infer(features, decode_length))
 
-  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha,
-                   use_tpu=False):
+  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha, use_tpu=False):
     """Beam search decoding.
 
     Args:
