forked from salesforce/CodeT5
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_gen.py
390 lines (347 loc) · 19.5 KB
/
run_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""
import os
import torch
import logging
import argparse
import math
import numpy as np
from tqdm import tqdm
import multiprocessing
import time
import sys
import pdb
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils import get_filenames, get_elapse_time, load_and_cache_gen_data
from configs import add_args, set_seed, set_dist
# Root logging setup: INFO and above, each record prefixed with a timestamp,
# its level and the name of the emitting logger.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, datefmt=_LOG_DATE_FORMAT, format=_LOG_FORMAT)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer):
    """Compute the perplexity of ``model`` over ``eval_data``.

    Args:
        args: Run arguments; reads ``eval_batch_size``, ``device`` and ``model_type``.
        eval_data: Dataset whose batches unpack into ``(source_ids, target_ids)``.
        eval_examples: Examples behind ``eval_data``; only their count is logged.
        model: Model to evaluate. It is switched to eval mode and left there.
        tokenizer: Supplies ``pad_token_id``, used to build the attention masks.

    Returns:
        ``exp`` of the mean per-batch loss, rounded to 5 decimals.

    Raises:
        ZeroDivisionError: If ``eval_data`` yields no batches.
    """
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                 num_workers=4, pin_memory=True)

    # Start evaluating model
    logger.info(" " + "***** Running ppl evaluation *****")
    logger.info(" Num examples = %d", len(eval_examples))
    logger.info(" Batch size = %d", args.eval_batch_size)

    model.eval()
    eval_loss, batch_num = 0, 0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"):
        source_ids, target_ids = tuple(t.to(args.device) for t in batch)
        # Every non-padding position is attended to.
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        target_mask = target_ids.ne(tokenizer.pad_token_id)

        with torch.no_grad():
            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

        # main() wraps the model in DataParallel when n_gpu > 1, in which case the
        # loss is not a scalar and .item() alone would raise. mean() mirrors the
        # training loop and leaves a scalar loss unchanged.
        eval_loss += loss.mean().item()
        batch_num += 1

    eval_loss = eval_loss / batch_num
    eval_ppl = round(np.exp(eval_loss), 5)
    return eval_ppl
def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag, criteria):
    """Generate predictions for ``eval_data``, write them to disk and score them.

    Three files are written under ``args.res_dir``: ``test_<criteria>.output``
    (predictions), ``test_<criteria>.gold`` (targets) and ``test_<criteria>.src``
    (sources). The ``test_`` prefix is used for every split.

    Args:
        args: Run arguments; reads ``eval_batch_size``, ``data_num``, ``device``,
            ``model_type``, ``beam_size``, ``task``, ``max_target_length``,
            ``res_dir`` and ``lang``.
        eval_data: Dataset whose batches carry the source ids at index 0.
        eval_examples: Examples aligned with ``eval_data``; their ``source``,
            ``target`` and (for summarize) ``idx`` attributes are read.
        model: Model to evaluate. It is switched to eval mode and left there.
        tokenizer: Supplies ``pad_token_id`` and decodes the predicted ids.
        split_tag: Name of the split (e.g. ``'dev'`` or ``'test'``). CodeBLEU is
            only computed when it equals ``'test'``.
        criteria: Tag embedded in the output file names.

    Returns:
        Dict with ``'em'`` and ``'bleu'``. ``'codebleu'`` is added for the defect
        task (always 0) and for non-summarize tasks on the test split.
    """
    logger.info(" ***** Running bleu evaluation on {} data*****".format(split_tag))
    logger.info(" Num examples = %d", len(eval_examples))
    logger.info(" Batch size = %d", args.eval_batch_size)

    eval_sampler = SequentialSampler(eval_data)
    if args.data_num == -1:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=4, pin_memory=True)
    else:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    model.eval()
    # A DataParallel wrapper only forwards __call__; generate() has to be invoked on
    # the wrapped module. This is the same unwrapping used when saving checkpoints.
    gen_model = model.module if hasattr(model, 'module') else model

    pred_ids = []
    bleu, codebleu = 0.0, 0.0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)):
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        with torch.no_grad():
            if args.model_type == 'roberta':
                preds = model(source_ids=source_ids, source_mask=source_mask)
                # Keep only the first candidate of each example.
                top_preds = [pred[0].cpu().numpy() for pred in preds]
            else:
                preds = gen_model.generate(source_ids,
                                           attention_mask=source_mask,
                                           use_cache=True,
                                           num_beams=args.beam_size,
                                           early_stopping=args.task == 'summarize',
                                           max_length=args.max_target_length)
                top_preds = list(preds.cpu().numpy())
            pred_ids.extend(top_preds)

    pred_nls = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in pred_ids]

    output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
    gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria))
    src_fn = os.path.join(args.res_dir, "test_{}.src".format(criteria))

    if args.task in ['defect']:
        # Defect detection is scored as accuracy over the 'true'/'false' strings.
        target_dict = {0: 'false', 1: 'true'}
        golds = [target_dict[ex.target] for ex in eval_examples]
        eval_acc = np.mean([int(p == g) for p, g in zip(pred_nls, golds)])
        result = {'em': eval_acc, 'bleu': 0, 'codebleu': 0}

        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                f.write(pred_nl.strip() + '\n')
                f1.write(target_dict[gold.target] + '\n')
                f2.write(gold.source.strip() + '\n')
            logger.info("Save the predictions into %s", output_fn)
    else:
        is_summarize = args.task in ['summarize']
        dev_accs, predictions = [], []
        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                dev_accs.append(pred_nl.strip() == gold.target.strip())
                if is_summarize:
                    # Summarize rows are keyed by example idx for smooth_bleu.
                    predictions.append(str(gold.idx) + '\t' + pred_nl)
                    f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n')
                    f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n')
                    f2.write(str(gold.idx) + '\t' + gold.source.strip() + '\n')
                else:
                    f.write(pred_nl.strip() + '\n')
                    f1.write(gold.target.strip() + '\n')
                    f2.write(gold.source.strip() + '\n')

        # Scoring happens after the files are closed because the scorers read them back.
        if is_summarize:
            (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn)
            bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
        else:
            bleu = round(_bleu(gold_fn, output_fn), 2)
            if split_tag == 'test' and args.task in ['refine', 'translate', 'concode']:
                codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)

        em = np.mean(dev_accs) * 100
        result = {'em': em, 'bleu': bleu}
        if args.task != 'summarize' and split_tag == 'test':
            result['codebleu'] = codebleu * 100

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info(" %s = %s", key, str(round(result[key], 4)))

    return result
def _unwrap_model(model):
    """Return the wrapped module if ``model`` exposes ``.module`` (e.g. DataParallel), else ``model``."""
    return model.module if hasattr(model, 'module') else model


def _save_model_weights(model, checkpoint_dir):
    """Save the unwrapped model's state dict into ``checkpoint_dir`` and return the file path."""
    output_model_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
    torch.save(_unwrap_model(model).state_dict(), output_model_file)
    return output_model_file


def _report_early_stop(fa, cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt):
    """Log the early-stop decision and append it to the summary file ``fa``."""
    early_stop_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
        cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
    logger.info(early_stop_str)
    fa.write(early_stop_str)


def _run_training(args, model, tokenizer, pool, fa, t0):
    """Train ``model``, evaluating on the dev set after every epoch.

    Tracks the best dev perplexity and the best dev BLEU/EM score, saves the
    corresponding checkpoints under ``args.output_dir`` and stops early once both
    metrics have failed to improve for more than ``args.patience`` epochs.

    Args:
        args: Run arguments.
        model: Model to train (possibly wrapped in DataParallel).
        tokenizer: Supplies ``pad_token_id`` and is passed to data loading/evaluation.
        pool: Process pool handed to ``load_and_cache_gen_data``.
        fa: Open summary log; best-metric changes and early stops are appended.
        t0: Start timestamp used for the elapsed-time log line.
    """
    # Only rank -1/0 of a full-data run owns a TensorBoard writer. Every later use is
    # guarded on ``tb_writer is not None`` so other ranks no longer hit a NameError.
    tb_writer = None
    if args.local_rank in [-1, 0] and args.data_num == -1:
        summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:]))
        tb_writer = SummaryWriter(summary_fn)

    # Prepare training data loader
    _, train_data = load_and_cache_gen_data(args, args.train_filename, pool, tokenizer, 'train')
    train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size,
                                  num_workers=4, pin_memory=True)

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    # NOTE(review): this counts batches, while scheduler.step() below only runs once per
    # gradient_accumulation_steps batches. With accumulation > 1 the schedule never
    # reaches its end. Left unchanged to keep the existing LR curve; confirm intent.
    num_train_optimization_steps = args.num_train_epochs * len(train_dataloader)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=args.warmup_steps,
                                                num_training_steps=num_train_optimization_steps)

    # Start training
    train_example_num = len(train_data)
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", train_example_num)
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Batch num = %d", math.ceil(train_example_num / args.train_batch_size))
    logger.info(" Num epoch = %d", args.num_train_epochs)

    dev_dataset = {}
    global_step, best_bleu_em, best_ppl = 0, -1, 1e6
    not_loss_dec_cnt = 0
    # Without BLEU evaluation this counter is pinned at 1e6, so early stopping is
    # decided by the perplexity counter alone (unless patience is itself >= 1e6).
    not_bleu_em_inc_cnt = 0 if args.do_eval_bleu else 1e6

    for cur_epoch in range(args.start_epoch, int(args.num_train_epochs)):
        if isinstance(train_sampler, DistributedSampler):
            # Without set_epoch the sampler replays the same shuffle every epoch.
            train_sampler.set_epoch(cur_epoch)

        bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training")
        nb_tr_steps, tr_loss = 0, 0
        model.train()
        for batch in bar:
            source_ids, target_ids = tuple(t.to(args.device) for t in batch)
            source_mask = source_ids.ne(tokenizer.pad_token_id)
            target_mask = target_ids.ne(tokenizer.pad_token_id)

            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()

            nb_tr_steps += 1
            loss.backward()

            if nb_tr_steps % args.gradient_accumulation_steps == 0:
                # Update parameters
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                # nb_tr_steps already counts the current batch, so divide by it directly
                # (the previous "+ 1" under-reported the running mean).
                train_loss = round(tr_loss * args.gradient_accumulation_steps / nb_tr_steps, 4)
                bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3)))

        if args.do_eval:
            # Eval model with dev dataset (loaded once, then reused across epochs)
            if 'dev_loss' in dev_dataset:
                eval_examples, eval_data = dev_dataset['dev_loss']
            else:
                eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev')
                dev_dataset['dev_loss'] = eval_examples, eval_data

            eval_ppl = eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer)
            result = {'epoch': cur_epoch, 'global_step': global_step, 'eval_ppl': eval_ppl}
            for key in sorted(result.keys()):
                logger.info(" %s = %s", key, str(result[key]))
            logger.info(" " + "*" * 20)
            if tb_writer is not None:
                tb_writer.add_scalar('dev_ppl', eval_ppl, cur_epoch)

            # save last checkpoint
            if args.save_last_checkpoints:
                last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
                os.makedirs(last_output_dir, exist_ok=True)
                output_model_file = _save_model_weights(model, last_output_dir)
                logger.info("Save the last model into %s", output_model_file)

            if eval_ppl < best_ppl:
                not_loss_dec_cnt = 0
                logger.info(" Best ppl:%s", eval_ppl)
                logger.info(" " + "*" * 20)
                fa.write("[%d] Best ppl changed into %.4f\n" % (cur_epoch, eval_ppl))
                best_ppl = eval_ppl

                # Save best checkpoint for best ppl. The directory is always created;
                # the weights are only written when always_save_model is set.
                output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl')
                os.makedirs(output_dir, exist_ok=True)
                if args.always_save_model:
                    output_model_file = _save_model_weights(model, output_dir)
                    logger.info("Save the best ppl model into %s", output_model_file)
            else:
                not_loss_dec_cnt += 1
                logger.info("Ppl does not decrease for %d epochs", not_loss_dec_cnt)
                if not_bleu_em_inc_cnt > args.patience and not_loss_dec_cnt > args.patience:
                    _report_early_stop(fa, cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                    break
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.do_eval_bleu:
                eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev',
                                                                   only_src=True, is_sample=True)

                result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev', 'e%d' % cur_epoch)
                dev_bleu, dev_em = result['bleu'], result['em']
                # Model-selection score: BLEU for summarize, EM for defect, their sum otherwise.
                if args.task in ['summarize']:
                    dev_bleu_em = dev_bleu
                elif args.task in ['defect']:
                    dev_bleu_em = dev_em
                else:
                    dev_bleu_em = dev_bleu + dev_em
                if tb_writer is not None:
                    tb_writer.add_scalar('dev_bleu_em', dev_bleu_em, cur_epoch)

                if dev_bleu_em > best_bleu_em:
                    not_bleu_em_inc_cnt = 0
                    logger.info(" [%d] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)",
                                cur_epoch, dev_bleu_em, dev_bleu, dev_em)
                    logger.info(" " + "*" * 20)
                    best_bleu_em = dev_bleu_em
                    fa.write("[%d] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % (
                        cur_epoch, best_bleu_em, dev_bleu, dev_em))

                    # Save best checkpoint for best bleu
                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu')
                    os.makedirs(output_dir, exist_ok=True)
                    if args.data_num == -1 or args.always_save_model:
                        output_model_file = _save_model_weights(model, output_dir)
                        logger.info("Save the best bleu model into %s", output_model_file)
                else:
                    not_bleu_em_inc_cnt += 1
                    logger.info("Bleu does not increase for %d epochs", not_bleu_em_inc_cnt)
                    if not_bleu_em_inc_cnt > args.patience and not_loss_dec_cnt > args.patience:
                        _report_early_stop(fa, cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                        break
        logger.info("***** CUDA.empty_cache() *****")
        torch.cuda.empty_cache()

    if tb_writer is not None:
        tb_writer.close()
    logger.info("Finish training and take %s", get_elapse_time(t0))


def _run_testing(args, model, tokenizer, pool, fa, t0):
    """Reload each saved checkpoint and score it on the test set.

    Results are logged, appended to the summary file ``fa`` and, when
    ``args.res_fn`` is set, appended to that file as well.
    """
    logger.info(" " + "***** Testing *****")
    logger.info(" Batch size = %d", args.eval_batch_size)

    # Checkpoints are saved from the unwrapped module, so they must be loaded into it
    # too; loading into a DataParallel wrapper would fail on the "module." key prefix.
    model_to_load = _unwrap_model(model)

    for criteria in ['best-bleu', 'best-ppl']:  # 'best-bleu', 'best-ppl', 'last'
        checkpoint_file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
        if not os.path.isfile(checkpoint_file):
            # Training writes these files only under certain flags (always_save_model,
            # do_eval_bleu, data_num), so a missing one is skipped, not fatal.
            logger.warning("Checkpoint %s not found, skip testing for %s", checkpoint_file, criteria)
            continue
        logger.info("Reload model from {}".format(checkpoint_file))
        model_to_load.load_state_dict(torch.load(checkpoint_file))
        eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                           only_src=True, is_sample=False)
        result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
        test_bleu, test_em = result['bleu'], result['em']
        test_codebleu = result['codebleu'] if 'codebleu' in result else 0
        result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
        logger.info(result_str)
        fa.write(result_str)
        if args.res_fn:
            with open(args.res_fn, 'a+') as f:
                f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), checkpoint_file))
                f.write(result_str)


def main():
    """Entry point: parse arguments, build the model, then train and/or test it.

    Progress milestones are appended to ``<output_dir>/summary.log``. The summary
    file and the worker pool are released even when training or testing raises.
    """
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logger.info(args)
    t0 = time.time()

    set_dist(args)
    set_seed(args)
    _, model, tokenizer = build_or_load_gen_model(args)
    model.to(args.device)
    if args.n_gpu > 1:
        # for DataParallel
        model = torch.nn.DataParallel(model)

    pool = multiprocessing.Pool(args.cpu_cont)
    try:
        args.train_filename, args.dev_filename, args.test_filename = get_filenames(
            args.data_dir, args.task, args.sub_task)
        with open(os.path.join(args.output_dir, 'summary.log'), 'a+') as fa:
            if args.do_train:
                _run_training(args, model, tokenizer, pool, fa, t0)
            if args.do_test:
                _run_testing(args, model, tokenizer, pool, fa, t0)
            logger.info("Finish and take {}".format(get_elapse_time(t0)))
            fa.write("Finish and take {}".format(get_elapse_time(t0)))
    finally:
        # Release the worker processes on every exit path.
        pool.terminate()
# Run the training/evaluation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()