Commit 9f32554

[Model]Transformer (#186)
* change the signature of node/edge filter
* upd filter
* Support multi-dimension node feature in SPMV
* push transformer
* remove some experimental settings
* stable version
* hotfix
* upd tutorial
* upd README
* merge
* remove redundency
* remove tqdm
* several changes
* Refactor
* Refactor
* tutorial train
* fixed a bug
* fixed perf issue
* upd
* change dir
* move un-related to contrib
* tutuorial code
* remove redundency
* upd
* upd
* upd
* upd
* improve viz
* universal done
* halt norm
* fixed a bug
* add draw graph
* fixed several bugs
* remove dependency on core
* upd format of README
* trigger
* trigger
* upd viz
* trigger
* add transformer tutorial
* fix tutorial
* fix readme
* small fix on tutorials
* url fix in readme
* fixed func link
* upd
1 parent 37feb47 commit 9f32554

25 files changed: +2849 -4 lines
+12
@@ -0,0 +1,12 @@
*~
data/
scripts/
checkpoints/
log/
*__pycache__*
*.pdf
*.tar.gz
*.zip
*.pyc
*.lprof
*.swp

examples/pytorch/transformer/.gitmodules

Whitespace-only changes.
+46
@@ -0,0 +1,46 @@
# Transformer in DGL

In this example we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) and the [Universal Transformer](https://arxiv.org/abs/1807.03819) with ACT in DGL.

The folder contains the training and inference modules (beam decoder) for the Transformer, and the training module for the Universal Transformer.

## Requirements

- PyTorch 0.4.1+
- networkx
- tqdm

## Usage

- For training:

```
python translation_train.py [--gpus id1,id2,...] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--universal]
```

- For evaluating the BLEU score on the test set (pass `--print` to also see the translated text):

```
python translation_test.py [--gpu id] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--checkpoint CHECKPOINT] [--print] [--universal]
```

Available datasets: `copy`, `sort`, `wmt14`, `multi30k` (default).
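
For instance, a concrete training invocation might look like the following (the GPU id, layer count, and batch size are illustrative values, not documented defaults):

```
python translation_train.py --gpus 0 --N 6 --dataset multi30k --batch 128
```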

## Test Results

### Transformer

- Multi30k: we achieve a BLEU score of 35.41 with the default settings on the Multi30k dataset, without using pre-trained embeddings (with the number of layers set to 2, the BLEU score reaches 36.45).
- WMT14: work in progress

### Universal Transformer

- Work in progress

## Notes

- Multi-GPU training is not supported yet (this will be fixed soon); specify only one GPU id when running the training script.

## Reference

- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)
- [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/)
@@ -0,0 +1,178 @@
from .graph import *
from .fields import *
from .utils import prepare_dataset
import os
import numpy as np


class ClassificationDataset:
    "Dataset class for classification task."
    def __init__(self):
        raise NotImplementedError


class TranslationDataset:
    '''
    Dataset class for translation task.
    By default, the source language shares the same vocabulary with the target language.
    '''
    INIT_TOKEN = '<sos>'
    EOS_TOKEN = '<eos>'
    PAD_TOKEN = '<pad>'
    MAX_LENGTH = 50

    def __init__(self, path, exts, train='train', valid='valid', test='test', vocab='vocab.txt', replace_oov=None):
        vocab_path = os.path.join(path, vocab)
        self.src = {}
        self.tgt = {}
        with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
            self.src['train'] = f.readlines()
        with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
            self.tgt['train'] = f.readlines()
        with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
            self.src['valid'] = f.readlines()
        with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
            self.tgt['valid'] = f.readlines()
        with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
            self.src['test'] = f.readlines()
        with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
            self.tgt['test'] = f.readlines()

        if not os.path.exists(vocab_path):
            self._make_vocab(vocab_path)

        vocab = Vocab(init_token=self.INIT_TOKEN,
                      eos_token=self.EOS_TOKEN,
                      pad_token=self.PAD_TOKEN,
                      unk_token=replace_oov)
        vocab.load(vocab_path)
        self.vocab = vocab
        strip_func = lambda x: x[:self.MAX_LENGTH]
        self.src_field = Field(vocab,
                               preprocessing=None,
                               postprocessing=strip_func)
        self.tgt_field = Field(vocab,
                               preprocessing=lambda seq: [self.INIT_TOKEN] + seq + [self.EOS_TOKEN],
                               postprocessing=strip_func)

    def get_seq_by_id(self, idx, mode='train', field='src'):
        "Get a raw sequence from the dataset by index, mode (train/valid/test) and field (src/tgt)."
        if field == 'src':
            return self.src[mode][idx].strip().split()
        else:
            return [self.INIT_TOKEN] + self.tgt[mode][idx].strip().split() + [self.EOS_TOKEN]

    def _make_vocab(self, path, thres=2):
        # Count token occurrences over all splits. The count starts at 0 on the
        # first occurrence, so a token is written to the vocab file only once it
        # has appeared more than thres + 1 times.
        word_dict = {}
        for mode in ['train', 'valid', 'test']:
            for line in self.src[mode] + self.tgt[mode]:
                for token in line.strip().split():
                    if token not in word_dict:
                        word_dict[token] = 0
                    else:
                        word_dict[token] += 1

        with open(path, 'w') as f:
            for k, v in word_dict.items():
                if v > thres:
                    print(k, file=f)

    @property
    def vocab_size(self):
        return len(self.vocab)

    @property
    def pad_id(self):
        return self.vocab[self.PAD_TOKEN]

    @property
    def sos_id(self):
        return self.vocab[self.INIT_TOKEN]

    @property
    def eos_id(self):
        return self.vocab[self.EOS_TOKEN]

    def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
        '''
        Create batched graphs corresponding to mini-batches of the dataset.
        args:
            graph_pool: a GraphPool object for acceleration.
            mode: train/valid/test
            batch_size: batch size
            devices: ['cpu'] or a list of gpu ids
            k: beam size (only required for test)
        '''
        dev_id, gs = 0, []
        src_data, tgt_data = self.src[mode], self.tgt[mode]
        n = len(src_data)
        order = np.random.permutation(n) if mode == 'train' else range(n)
        src_buf, tgt_buf = [], []

        for idx in order:
            src_sample = self.src_field(
                src_data[idx].strip().split())
            tgt_sample = self.tgt_field(
                tgt_data[idx].strip().split())
            src_buf.append(src_sample)
            tgt_buf.append(tgt_sample)
            if len(src_buf) == batch_size:
                if mode == 'test':
                    assert len(devices) == 1  # we only allow a single gpu for inference
                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
                else:
                    gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
                    dev_id += 1
                    if dev_id == len(devices):
                        yield gs if len(devices) > 1 else gs[0]
                        dev_id, gs = 0, []
                src_buf, tgt_buf = [], []

        if len(src_buf) != 0:
            if mode == 'test':
                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
            else:
                gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
                yield gs if len(devices) > 1 else gs[0]

    def get_sequence(self, batch):
        "Return a list of sequences from a list of index arrays."
        ret = []
        filter_list = set([self.pad_id, self.sos_id, self.eos_id])
        for seq in batch:
            try:
                l = seq.index(self.eos_id)
            except (ValueError, AttributeError):  # no <eos> found, or seq lacks .index
                l = len(seq)
            ret.append(' '.join(self.vocab[token] for token in seq[:l] if token not in filter_list))
        return ret


def get_dataset(dataset):
    "Wrap a set of datasets as examples."
    prepare_dataset(dataset)
    if dataset == 'babi':
        raise NotImplementedError
    elif dataset == 'copy' or dataset == 'sort':
        return TranslationDataset(
            'data/{}'.format(dataset),
            ('in', 'out'),
            train='train',
            valid='valid',
            test='test',
        )
    elif dataset == 'multi30k':
        return TranslationDataset(
            'data/multi30k',
            ('en.atok', 'de.atok'),
            train='train',
            valid='val',
            test='test2016',
            replace_oov='<unk>'
        )
    elif dataset == 'wmt14':
        return TranslationDataset(
            'data/wmt14',
            ('en', 'de'),
            train='train.tok.clean.bpe.32000',
            valid='newstest2013.tok.bpe.32000',
            test='newstest2014.tok.bpe.32000',
            vocab='vocab.bpe.32000')
    else:
        raise KeyError()
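
As a rough usage sketch (not part of the commit): the snippet below assumes the multi30k files have already been fetched by `prepare_dataset` and that a `GraphPool` object from the example's `graph` module is available; `graph_pool` is a placeholder for such an object.

```
# Illustrative only: requires the downloaded multi30k data to run.
dataset = get_dataset('multi30k')
print(dataset.vocab_size, dataset.pad_id, dataset.sos_id, dataset.eos_id)

# Inspect a raw validation pair by index.
src = dataset.get_seq_by_id(0, mode='valid', field='src')
tgt = dataset.get_seq_by_id(0, mode='valid', field='tgt')
print(src, tgt)

# Iterate over batched graphs for training; `graph_pool` is assumed to be a
# GraphPool instance built elsewhere in the example.
# for g in dataset(graph_pool, mode='train', batch_size=128, devices=['cpu']):
#     ...  # feed the batched graph to the model
```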
@@ -0,0 +1,63 @@
class Vocab:
    "A token <-> index mapping built from a vocabulary file plus special tokens."
    def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None):
        self.init_token = init_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab_lst = []
        self.vocab_dict = None

    def load(self, path):
        # Special tokens come first, followed by one token per line from the vocab file.
        if self.init_token is not None:
            self.vocab_lst.append(self.init_token)
        if self.eos_token is not None:
            self.vocab_lst.append(self.eos_token)
        if self.pad_token is not None:
            self.vocab_lst.append(self.pad_token)
        if self.unk_token is not None:
            self.vocab_lst.append(self.unk_token)
        with open(path, 'r') as f:
            for token in f.readlines():
                token = token.strip()
                self.vocab_lst.append(token)
        self.vocab_dict = {
            v: k for k, v in enumerate(self.vocab_lst)
        }

    def __len__(self):
        return len(self.vocab_lst)

    def __getitem__(self, key):
        # str -> index (falling back to the unk token), int -> token string.
        if isinstance(key, str):
            if key in self.vocab_dict:
                return self.vocab_dict[key]
            else:
                return self.vocab_dict[self.unk_token]
        else:
            return self.vocab_lst[key]


class Field:
    "Preprocess, numericalize and postprocess a token sequence with a given Vocab."
    def __init__(self, vocab, preprocessing=None, postprocessing=None):
        self.vocab = vocab
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x):
        if self.preprocessing is not None:
            return self.preprocessing(x)
        return x

    def postprocess(self, x):
        if self.postprocessing is not None:
            return self.postprocessing(x)
        return x

    def numericalize(self, x):
        return [self.vocab[token] for token in x]

    def __call__(self, x):
        return self.postprocess(
            self.numericalize(
                self.preprocess(x)
            )
        )
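
A small self-contained sketch of how `Vocab` and `Field` compose; the vocabulary file and tokens below are made up for illustration, and `Vocab`/`Field` are assumed importable from the module above.

```
import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    vocab_path = os.path.join(d, 'vocab.txt')
    with open(vocab_path, 'w') as f:
        f.write('a\nman\nwalks\n')            # toy vocabulary, one token per line

    vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                  pad_token='<pad>', unk_token='<unk>')
    vocab.load(vocab_path)                    # special tokens take ids 0-3, file tokens follow

    field = Field(vocab,
                  preprocessing=lambda seq: ['<sos>'] + seq + ['<eos>'],
                  postprocessing=lambda seq: seq[:50])

    ids = field('a man walks home'.split())   # 'home' is out of vocabulary -> <unk>
    print(ids)                                # [0, 4, 5, 6, 3, 1]
    print([vocab[i] for i in ids])            # ['<sos>', 'a', 'man', 'walks', '<unk>', '<eos>']
```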
