Commit f9faae4
Committed Oct 19, 2019
update pretraining
1 parent b604f19

18 files changed: +100 −74 lines
datasets/conala/dataset.py (+72 −62)

@@ -4,6 +4,7 @@
 import pickle
 
 from components.action_info import get_action_infos
+from datasets.conala.evaluator import ConalaEvaluator
 from datasets.conala.util import *
 from asdl.lang.py3.py3_transition_system import python_ast_to_asdl_ast, asdl_ast_to_python_ast, Python3TransitionSystem
 
@@ -16,7 +17,8 @@
 from components.action_info import ActionInfo
 
 
-def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3):
+def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3,
+                              mined_data_file=None, num_mined=0):
     np.random.seed(1234)
 
     asdl_text = open(grammar_file).read()
@@ -32,19 +34,12 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
     dev_examples = train_examples[:200]
     train_examples = train_examples[200:]
 
-    # full_train_examples = train_examples[:]
-    # np.random.shuffle(train_examples)
-    # dev_examples = []
-    # dev_questions = set()
-    # dev_examples_id = []
-    # for i, example in enumerate(full_train_examples):
-    #     qid = example.meta['example_dict']['question_id']
-    #     if qid not in dev_questions and len(dev_examples) < 200:
-    #         dev_questions.add(qid)
-    #         dev_examples.append(example)
-    #         dev_examples_id.append(i)
-
-    # train_examples = [e for i, e in enumerate(full_train_examples) if i not in dev_examples_id]
+    if mined_data_file and num_mined > 0:
+        print("use mined data: ", num_mined)
+        mined_examples = preprocess_dataset(mined_data_file, name='mined', transition_system=transition_system,
+                                            firstk=num_mined)
+        train_examples += mined_examples
+
     print(f'{len(train_examples)} training instances', file=sys.stderr)
     print(f'{len(dev_examples)} dev instances', file=sys.stderr)
 
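Note: the hunk above is the core of the pretraining change. Up to num_mined examples from the mined corpus are preprocessed and appended to the annotated training set before vocabulary construction. A minimal standalone sketch of the same mixing step (preprocess_dataset is this file's own function; mix_in_mined is a hypothetical wrapper):

    # Hypothetical wrapper around the mixing logic added above.
    # `firstk` truncates the mined corpus to its first k entries.
    def mix_in_mined(train_examples, transition_system,
                     mined_data_file=None, num_mined=0):
        if mined_data_file and num_mined > 0:
            mined_examples = preprocess_dataset(mined_data_file, name='mined',
                                                transition_system=transition_system,
                                                firstk=num_mined)
            return train_examples + mined_examples
        return train_examples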
@@ -71,58 +66,65 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
     print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
     print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_lens))), file=sys.stderr)
 
-    pickle.dump(train_examples, open('data/conala/train.var_str_sep.bin', 'wb'))
-    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.bin', 'wb'))
-    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.bin', 'wb'))
-    pickle.dump(test_examples, open('data/conala/test.var_str_sep.bin', 'wb'))
-    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.bin' % (src_freq, code_freq), 'wb'))
+    pickle.dump(train_examples, open('data/conala/train.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(test_examples, open('data/conala/test.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.mined_%s.bin' % (src_freq, code_freq, num_mined), 'wb'))
 
 
-def preprocess_dataset(file_path, transition_system, name='train'):
-    dataset = json.load(open(file_path))
+def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
+    try:
+        dataset = json.load(open(file_path))
+    except:
+        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
     examples = []
     evaluator = ConalaEvaluator(transition_system)
 
     f = open(file_path + '.debug', 'w')
 
     for i, example_json in enumerate(dataset):
-        example_dict = preprocess_example(example_json)
-        if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
-            print(example_json['question_id'])
+        if firstk and i >= firstk:
+            break
+        try:
+            example_dict = preprocess_example(example_json)
+            if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
+                print(example_json['question_id'])
+                continue
+
+            python_ast = ast.parse(example_dict['canonical_snippet'])
+            canonical_code = astor.to_source(python_ast).strip()
+            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
+            tgt_actions = transition_system.get_actions(tgt_ast)
+
+            # sanity check
+            hyp = Hypothesis()
+            for t, action in enumerate(tgt_actions):
+                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
+                if isinstance(action, ApplyRuleAction):
+                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
+
+                p_t = -1
+                f_t = None
+                if hyp.frontier_node:
+                    p_t = hyp.frontier_node.created_time
+                    f_t = hyp.frontier_field.field.__repr__(plain=True)
+
+                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
+                hyp = hyp.clone_and_apply_action(action)
+
+            assert hyp.frontier_node is None and hyp.frontier_field is None
+            hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
+            assert code_from_hyp == canonical_code
+
+            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
+            assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
+            assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
+                                                 transition_system.surface_code_to_ast(example_json['snippet']))
+
+            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
+        except:
             continue
-
-        python_ast = ast.parse(example_dict['canonical_snippet'])
-        canonical_code = astor.to_source(python_ast).strip()
-        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
-        tgt_actions = transition_system.get_actions(tgt_ast)
-
-        # sanity check
-        hyp = Hypothesis()
-        for t, action in enumerate(tgt_actions):
-            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
-            if isinstance(action, ApplyRuleAction):
-                assert action.production in transition_system.get_valid_continuating_productions(hyp)
-
-            p_t = -1
-            f_t = None
-            if hyp.frontier_node:
-                p_t = hyp.frontier_node.created_time
-                f_t = hyp.frontier_field.field.__repr__(plain=True)
-
-            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
-            hyp = hyp.clone_and_apply_action(action)
-
-        assert hyp.frontier_node is None and hyp.frontier_field is None
-        hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
-        assert code_from_hyp == canonical_code
-
-        decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
-        assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
-        assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
-                                             transition_system.surface_code_to_ast(example_json['snippet']))
-
-        tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
-
 
         example = Example(idx=f'{i}-{example_json["question_id"]}',
                           src_sent=example_dict['intent_tokens'],
                           tgt_actions=tgt_action_infos,
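Note: two behavioural changes in this hunk. The loader now falls back to line-delimited JSON (the conala-mined.jsonl format) when whole-file json.load fails, and the per-example pipeline is wrapped in a bare try/except so malformed mined examples are silently skipped. The bare excepts also swallow genuine errors; a slightly more targeted sketch of just the loading fallback, assuming only the two file formats used here:

    import json

    def load_json_or_jsonl(file_path):
        # conala-train/test are single JSON documents; conala-mined is JSONL.
        # Catching only JSONDecodeError lets real I/O errors surface.
        with open(file_path) as f:
            text = f.read()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return [json.loads(line) for line in text.splitlines() if line.strip()]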
@@ -136,7 +138,10 @@ def preprocess_dataset(file_path, transition_system, name='train'):
 
         # log!
         f.write(f'Example: {example.idx}\n')
-        f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        if 'rewritten_intent' in example.meta['example_dict']:
+            f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        else:
+            f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
         f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
         f.write(f"\n")
         f.write(f"Utterance: {' '.join(example.src_sent)}\n")
@@ -150,9 +155,11 @@ def preprocess_dataset(file_path, transition_system, name='train'):
 
 def preprocess_example(example_json):
     intent = example_json['intent']
-    rewritten_intent = example_json['rewritten_intent']
+    if 'rewritten_intent' in example_json:
+        rewritten_intent = example_json['rewritten_intent']
+    else:
+        rewritten_intent = None
     snippet = example_json['snippet']
-    question_id = example_json['question_id']
 
     if rewritten_intent is None:
        rewritten_intent = intent
@@ -190,8 +197,11 @@ def generate_vocab_for_paraphrase_model(vocab_path, save_path):
 
 if __name__ == '__main__':
     # the json files can be download from http://conala-corpus.github.io
-    preprocess_conala_dataset(train_file='data/conala/conala-train.json',
+    for num in (10000, 20000):
+        preprocess_conala_dataset(train_file='data/conala/conala-train.json',
                               test_file='data/conala/conala-test.json',
-                              grammar_file='asdl/lang/py3/py3_asdl.simplified.txt', src_freq=3, code_freq=3)
+                                  mined_data_file='data/conala/conala-mined.jsonl',
+                                  grammar_file='asdl/lang/py3/py3_asdl.simplified.txt',
+                                  src_freq=3, code_freq=3, num_mined=num)
 
     # generate_vocab_for_paraphrase_model('data/conala/vocab.src_freq3.code_freq3.bin', 'data/conala/vocab.para.src_freq3.code_freq3.bin')
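Note: with the new __main__ block, one run preprocesses the corpus twice and writes a full set of pickles per mined size; the vocabulary is rebuilt for each size since the mined snippets change token frequencies. (full_train_examples is still pickled even though nothing in this diff defines it, suggesting it is set elsewhere or is a latent NameError.) A sketch of the artifact names the dump calls above produce:

    # Sketch: filenames per mined size, reconstructed from the
    # pickle.dump calls in preprocess_conala_dataset above.
    for num_mined in (10000, 20000):
        print('data/conala/train.var_str_sep.mined_{}.bin'.format(num_mined))
        print('data/conala/train.var_str_sep.full.mined_{}.bin'.format(num_mined))
        print('data/conala/dev.var_str_sep.mined_{}.bin'.format(num_mined))
        print('data/conala/test.var_str_sep.mined_{}.bin'.format(num_mined))
        print('data/conala/vocab.var_str_sep.new_dev.src_freq3.code_freq3.mined_{}.bin'
              .format(num_mined))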

datasets/conala/evaluator.py (+11 −2)

@@ -1,3 +1,5 @@
+import csv
+
 from components.evaluator import Evaluator
 from common.registerable import Registrable
 from components.dataset import Dataset
@@ -32,7 +34,10 @@ def get_sentence_bleu(self, example, hyp):
                                       tokenize_for_bleu_eval(hyp.decanonical_code),
                                       smoothing_function=SmoothingFunction().method3)
 
-    def evaluate_dataset(self, dataset, decode_results, fast_mode=False):
+
+    def evaluate_dataset(self, dataset, decode_results, fast_mode=False, args=None):
+        if args.save_decode_to:
+            csv_writer = csv.writer(open(args.save_decode_to + '.csv', 'w'))
         examples = dataset.examples if isinstance(dataset, Dataset) else dataset
         assert len(examples) == len(decode_results)
 
@@ -88,7 +93,11 @@ def evaluate_dataset(self, dataset, decode_results, fast_mode=False):
 
                 top_decanonical_code_tokens = hyp_list[0].decanonical_code_tokens
                 sent_bleu_score = hyp_list[0].bleu_score
-
+                # write results to file
+                if args.save_decode_to:
+                    csv_writer.writerow([" ".join(example.src_sent),
+                                         " ".join(example.reference_code_tokens),
+                                         " ".join(top_decanonical_code_tokens)])
                 best_hyp_idx = np.argmax(example_hyp_bleu_scores)
                 oracle_sent_bleu = example_hyp_bleu_scores[best_hyp_idx]
                 _best_hyp_code_tokens = hyp_list[best_hyp_idx].decanonical_code_tokens
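Note: evaluate_dataset now optionally dumps one (source, reference, top hypothesis) row per example to <save_decode_to>.csv. Since args defaults to None but args.save_decode_to is read unconditionally, callers must always pass args. A defensive sketch of the writer setup (make_decode_writer is a hypothetical helper, not part of this commit):

    import csv

    def make_decode_writer(args):
        # Tolerate args=None instead of raising AttributeError, and keep a
        # handle on the file object so it can be closed deterministically.
        if args is not None and getattr(args, 'save_decode_to', None):
            f = open(args.save_decode_to + '.csv', 'w', newline='')
            return csv.writer(f), f
        return None, None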

evaluation.py (+1 −1)

@@ -58,7 +58,7 @@ def decode(examples, model, args, verbose=False, **kwargs):
 def evaluate(examples, parser, evaluator, args, verbose=False, return_decode_result=False, eval_top_pred_only=False):
     decode_results = decode(examples, parser, args, verbose=verbose)
 
-    eval_result = evaluator.evaluate_dataset(examples, decode_results, fast_mode=eval_top_pred_only)
+    eval_result = evaluator.evaluate_dataset(examples, decode_results, fast_mode=eval_top_pred_only, args=args)
 
     if return_decode_result:
         return eval_result, decode_results

exp.py (+1)

@@ -145,6 +145,7 @@ def train(args):
         model.save(model_file)
 
         # perform validation
+        is_better = False
         if args.dev_file:
             if epoch % args.valid_every_epoch == 0:
                 print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
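Note: initializing is_better before the branch gives the flag a defined value on epochs where validation is skipped, so later code that reads it (e.g. patience and checkpointing logic) cannot hit a NameError. The pattern in isolation, with hypothetical validate/best_score names:

    is_better = False                               # safe default when validation is skipped
    if args.dev_file and epoch % args.valid_every_epoch == 0:
        dev_score = validate(model, dev_examples)   # hypothetical helper
        is_better = dev_score > best_score          # hypothetical tracker
    if is_better:
        model.save(model_file)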

model/nn_utils.py (+5 −2)

@@ -87,8 +87,11 @@ def to_input_variable(sequences, vocab, cuda=False, training=True, append_bounda
 
     word_ids = word2id(sequences, vocab)
     sents_t = input_transpose(word_ids, vocab['<pad>'])
-
-    sents_var = Variable(torch.LongTensor(sents_t), volatile=(not training), requires_grad=False)
+    if training:
+        sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
+    else:
+        with torch.no_grad():
+            sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
     if cuda:
         sents_var = sents_var.cuda()
 
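Note: this replaces the volatile flag, which was deprecated in PyTorch 0.4 and later removed, with torch.no_grad(). One caveat: volatile propagated through every downstream operation, whereas no_grad() only affects ops executed inside the context, so wrapping just the tensor construction (as above) does not by itself disable gradient tracking for the rest of the forward pass. A sketch of the fuller modern idiom, assuming PyTorch >= 0.4 and a hypothetical model:

    import torch

    def encode(model, sents_t, training=True):
        # Input ids are leaf tensors and never require grad; what matters
        # is whether the *forward pass* runs under no_grad().
        sents = torch.LongTensor(sents_t)
        if training:
            return model(sents)
        with torch.no_grad():       # disables tracking for all ops inside
            return model(sents)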

model/parser.py (+4 −2)

@@ -503,7 +503,8 @@ def parse(self, src_sent, context=None, beam_size=5, debug=False):
 
         zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())
 
-        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)
+        with torch.no_grad():
+            hyp_scores = Variable(self.new_tensor([0.]))
 
         # For computing copy probabilities, we marginalize over tokens with the same surface form
         # `aggregated_primitive_tokens` stores the position of occurrence of each source token
@@ -525,7 +526,8 @@
             exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))
 
             if t == 0:
-                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
+                with torch.no_grad():
+                    x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_())
                 if args.no_parent_field_type_embed is False:
                     offset = args.action_embed_size  # prev_action
                     offset += args.att_vec_size * (not args.no_input_feed)
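Note: the same volatile-to-no_grad substitution is applied at the two tensor-construction sites inside parse(). Because beam-search decoding never needs gradients, an alternative (not what this commit does) is to run the whole method under no_grad(), which torch supports as a decorator; a sketch with a placeholder class:

    import torch

    class Parser(torch.nn.Module):             # placeholder class
        @torch.no_grad()                       # entire beam search runs gradient-free
        def parse(self, src_sent, context=None, beam_size=5, debug=False):
            hyp_scores = self.new_tensor([0.])  # no volatile= needed
            ...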

scripts/atis/train.sh (mode changed 100755 → 100644)

scripts/conala/train.sh (+6 −5, mode changed 100755 → 100644)

@@ -1,10 +1,11 @@
 #!/bin/bash
 set -e
 
-seed=${1:-0}
-vocab="data/conala/vocab.var_str_sep.src_freq3.code_freq3.bin"
-train_file="data/conala/train.var_str_sep.bin"
-dev_file="data/conala/dev.var_str_sep.bin"
+seed=0
+mined_num=$1
+vocab="data/conala/vocab.var_str_sep.new_dev.src_freq3.code_freq3.mined_${mined_num}.bin"
+train_file="data/conala/train.var_str_sep.mined_${mined_num}.bin"
+dev_file="data/conala/dev.var_str_sep.mined_${mined_num}.bin"
 dropout=0.3
 hidden_size=256
 embed_size=128
@@ -17,7 +18,7 @@ lr_decay=0.5
 beam_size=15
 lstm='lstm' # lstm
 lr_decay_after_epoch=15
-model_name=model.sup.conala.${lstm}.hidden${hidden_size}.embed${embed_size}.action${action_embed_size}.field${field_embed_size}.type${type_embed_size}.dr${dropout}.lr${lr}.lr_de${lr_decay}.lr_da${lr_decay_after_epoch}.beam${beam_size}.$(basename ${vocab}).$(basename ${train_file}).glorot.par_state.seed${seed}
+model_name=model.sup.conala.${lstm}.hidden${hidden_size}.embed${embed_size}.action${action_embed_size}.field${field_embed_size}.type${type_embed_size}.dr${dropout}.lr${lr}.lr_de${lr_decay}.lr_da${lr_decay_after_epoch}.beam${beam_size}.$(basename ${vocab}).$(basename ${train_file}).glorot.par_state.seed${seed}.mined_${mined_num}
 
 echo "**** Writing results to logs/conala/${model_name}.log ****"
 mkdir -p logs/conala
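Note: the script's interface changes accordingly. The seed is now fixed at 0, and the first positional argument selects the mined-example count, which must match a preprocessed binary set; given the preprocessing loop above, the matching invocations would be `scripts/conala/train.sh 10000` or `scripts/conala/train.sh 20000`.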

scripts/django/train.sh (mode changed 100755 → 100644)

scripts/geo/test.sh (mode changed 100755 → 100644)

scripts/geo/train.sh (mode changed 100755 → 100644)

scripts/jobs/test.sh (mode changed 100755 → 100644)

scripts/jobs/train.sh (mode changed 100755 → 100644)

scripts/wikisql/test.sh (mode changed 100755 → 100644)

scripts/wikisql/train.sh (mode changed 100755 → 100644)

server/static/d3Tree.js (mode changed 100755 → 100644)

server/static/parser.js (mode changed 100755 → 100644)

server/static/tree-viewer.css (mode changed 100755 → 100644)

0 commit comments