import pickle

from components.action_info import get_action_infos
+from datasets.conala.evaluator import ConalaEvaluator
from datasets.conala.util import *
from asdl.lang.py3.py3_transition_system import python_ast_to_asdl_ast, asdl_ast_to_python_ast, Python3TransitionSystem

from components.action_info import ActionInfo


-def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3):
+def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3,
+                              mined_data_file=None, num_mined=0):
    np.random.seed(1234)

    asdl_text = open(grammar_file).read()
@@ -32,19 +34,12 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3):
    dev_examples = train_examples[:200]
    train_examples = train_examples[200:]

-    # full_train_examples = train_examples[:]
-    # np.random.shuffle(train_examples)
-    # dev_examples = []
-    # dev_questions = set()
-    # dev_examples_id = []
-    # for i, example in enumerate(full_train_examples):
-    #     qid = example.meta['example_dict']['question_id']
-    #     if qid not in dev_questions and len(dev_examples) < 200:
-    #         dev_questions.add(qid)
-    #         dev_examples.append(example)
-    #         dev_examples_id.append(i)
-
-    # train_examples = [e for i, e in enumerate(full_train_examples) if i not in dev_examples_id]
+    if mined_data_file and num_mined > 0:
+        print("use mined data: ", num_mined)
+        mined_examples = preprocess_dataset(mined_data_file, name='mined', transition_system=transition_system,
+                                            firstk=num_mined)
+        train_examples += mined_examples
+

    print(f'{len(train_examples)} training instances', file=sys.stderr)
    print(f'{len(dev_examples)} dev instances', file=sys.stderr)
@@ -71,58 +66,65 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3):
    print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
    print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_lens))), file=sys.stderr)

-    pickle.dump(train_examples, open('data/conala/train.var_str_sep.bin', 'wb'))
-    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.bin', 'wb'))
-    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.bin', 'wb'))
-    pickle.dump(test_examples, open('data/conala/test.var_str_sep.bin', 'wb'))
-    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.bin' % (src_freq, code_freq), 'wb'))
+    pickle.dump(train_examples, open('data/conala/train.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(test_examples, open('data/conala/test.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.mined_%s.bin' % (src_freq, code_freq, num_mined), 'wb'))


-def preprocess_dataset(file_path, transition_system, name='train'):
-    dataset = json.load(open(file_path))
+def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
+    try:
+        dataset = json.load(open(file_path))
+    except:
+        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    f = open(file_path + '.debug', 'w')

    for i, example_json in enumerate(dataset):
-        example_dict = preprocess_example(example_json)
-        if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
-            print(example_json['question_id'])
+        if firstk and i >= firstk:
+            break
+        try:
+            example_dict = preprocess_example(example_json)
+            if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
+                print(example_json['question_id'])
+                continue
+
+            python_ast = ast.parse(example_dict['canonical_snippet'])
+            canonical_code = astor.to_source(python_ast).strip()
+            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
+            tgt_actions = transition_system.get_actions(tgt_ast)
+
+            # sanity check
+            hyp = Hypothesis()
+            for t, action in enumerate(tgt_actions):
+                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
+                if isinstance(action, ApplyRuleAction):
+                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
+
+                p_t = -1
+                f_t = None
+                if hyp.frontier_node:
+                    p_t = hyp.frontier_node.created_time
+                    f_t = hyp.frontier_field.field.__repr__(plain=True)
+
+                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
+                hyp = hyp.clone_and_apply_action(action)
+
+            assert hyp.frontier_node is None and hyp.frontier_field is None
+            hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
+            assert code_from_hyp == canonical_code
+
+            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
+            assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
+            assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
+                                                 transition_system.surface_code_to_ast(example_json['snippet']))
+
+            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
+        except:
            continue
-
-        python_ast = ast.parse(example_dict['canonical_snippet'])
-        canonical_code = astor.to_source(python_ast).strip()
-        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
-        tgt_actions = transition_system.get_actions(tgt_ast)
-
-        # sanity check
-        hyp = Hypothesis()
-        for t, action in enumerate(tgt_actions):
-            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
-            if isinstance(action, ApplyRuleAction):
-                assert action.production in transition_system.get_valid_continuating_productions(hyp)
-
-            p_t = -1
-            f_t = None
-            if hyp.frontier_node:
-                p_t = hyp.frontier_node.created_time
-                f_t = hyp.frontier_field.field.__repr__(plain=True)
-
-            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
-            hyp = hyp.clone_and_apply_action(action)
-
-        assert hyp.frontier_node is None and hyp.frontier_field is None
-        hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
-        assert code_from_hyp == canonical_code
-
-        decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
-        assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
-        assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
-                                             transition_system.surface_code_to_ast(example_json['snippet']))
-
-        tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
-
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
@@ -136,7 +138,10 @@ def preprocess_dataset(file_path, transition_system, name='train'):

        # log!
        f.write(f'Example: {example.idx}\n')
-        f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        if 'rewritten_intent' in example.meta['example_dict']:
+            f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        else:
+            f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
        f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
@@ -150,9 +155,11 @@ def preprocess_dataset(file_path, transition_system, name='train'):

def preprocess_example(example_json):
    intent = example_json['intent']
-    rewritten_intent = example_json['rewritten_intent']
+    if 'rewritten_intent' in example_json:
+        rewritten_intent = example_json['rewritten_intent']
+    else:
+        rewritten_intent = None
    snippet = example_json['snippet']
-    question_id = example_json['question_id']

    if rewritten_intent is None:
        rewritten_intent = intent
@@ -190,8 +197,11 @@ def generate_vocab_for_paraphrase_model(vocab_path, save_path):

if __name__ == '__main__':
    # the json files can be downloaded from http://conala-corpus.github.io
-    preprocess_conala_dataset(train_file='data/conala/conala-train.json',
+    for num in (10000, 20000):
+        preprocess_conala_dataset(train_file='data/conala/conala-train.json',
                              test_file='data/conala/conala-test.json',
-                              grammar_file='asdl/lang/py3/py3_asdl.simplified.txt', src_freq=3, code_freq=3)
+                                  mined_data_file='data/conala/conala-mined.jsonl',
+                                  grammar_file='asdl/lang/py3/py3_asdl.simplified.txt',
+                                  src_freq=3, code_freq=3, num_mined=num)

    # generate_vocab_for_paraphrase_model('data/conala/vocab.src_freq3.code_freq3.bin', 'data/conala/vocab.para.src_freq3.code_freq3.bin')
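
Note (not part of the commit): a minimal standalone sketch of the loading behaviour the reworked preprocess_dataset relies on. conala-train.json and conala-test.json parse as a single JSON document, while conala-mined.jsonl is one JSON object per line, and only the first firstk mined entries are kept (num_mined is passed through as firstk). The helper name load_examples and the narrowed json.JSONDecodeError handler are illustrative assumptions; the commit itself uses a bare except.

import json

def load_examples(file_path, firstk=None):
    # Hypothetical helper mirroring the try/except fallback in preprocess_dataset:
    # first try to read the whole file as one JSON document (conala-train/test.json),
    # then fall back to JSON Lines (conala-mined.jsonl).
    with open(file_path) as f:
        text = f.read()
    try:
        dataset = json.loads(text)
    except json.JSONDecodeError:
        dataset = [json.loads(line) for line in text.splitlines() if line.strip()]
    # Keep only the first k entries, matching the firstk/num_mined truncation above.
    if firstk:
        dataset = dataset[:firstk]
    return dataset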
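
Also illustrative, assuming the __main__ loop above has been run from the repository root: each num_mined setting writes its own pickled splits, whose names follow the pattern in the pickle.dump calls, and they can be loaded back for inspection (the pickles contain Example objects from components.dataset, so that package must be importable).

import pickle

num_mined = 10000  # one of the settings iterated over in __main__
with open('data/conala/train.var_str_sep.mined_{}.bin'.format(num_mined), 'rb') as f:
    train_examples = pickle.load(f)
print(len(train_examples), 'training examples; first intent tokens:', train_examples[0].src_sent)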