# coding=utf-8

from __future__ import print_function

import re
import cPickle as pickle
import ast
import astor
import nltk
import sys

import numpy as np

from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField
from asdl.lang.py.py_asdl_helper import python_ast_to_asdl_ast, asdl_ast_to_python_ast
from asdl.lang.py.py_transition_system import PythonTransitionSystem
from asdl.hypothesis import *

from components.action_info import ActionInfo

# patterns for code fragments that cannot be parsed on their own:
# dangling elif/else/except/finally branches and bare decorators
p_elif = re.compile(r'^elif\s?')
p_else = re.compile(r'^else\s?')
p_try = re.compile(r'^try\s?')
p_except = re.compile(r'^except\s?')
p_finally = re.compile(r'^finally\s?')
p_decorator = re.compile(r'^@.*')

QUOTED_STRING_RE = re.compile(r"(?P<quote>['\"])(?P<string>.*?)(?<!\\)(?P=quote)")
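# QUOTED_STRING_RE captures (quote, contents) pairs and skips escaped closing quotes,
# e.g. findall() on: replace "foo" with 'bar'  ->  [('"', 'foo'), ("'", 'bar')]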


class Django(object):
    @staticmethod
    def canonicalize_code(code):
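        """
        Wrap a standalone code fragment (a dangling elif/else/except/finally
        branch, a bare decorator, or a lone block header) into a snippet that
        ast.parse accepts, e.g. 'elif x:' is prefixed with 'if True: pass' and
        suffixed with 'pass' so that it parses.
        """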
        if p_elif.match(code):
            code = 'if True: pass\n' + code

        if p_else.match(code):
            code = 'if True: pass\n' + code

        if p_try.match(code):
            code = code + 'pass\nexcept: pass'
        elif p_except.match(code):
            code = 'try: pass\n' + code
        elif p_finally.match(code):
            code = 'try: pass\n' + code

        if p_decorator.match(code):
            code = code + '\ndef dummy(): pass'

        if code[-1] == ':':
            code = code + 'pass'

        return code

    @staticmethod
    def canonicalize_query(query):
        """
        Canonicalize the query: replace quoted string literals with placeholder
        tokens (_STR:0_, _STR:1_, ...) and return the tokenized query together
        with a map from each literal to its placeholder.
        """
        str_count = 0
        str_map = dict()

        matches = QUOTED_STRING_RE.findall(query)
        # de-duplicate
        cur_replaced_strs = set()
        for match in matches:
            # when the pattern contains groups, findall() returns tuples of groups
            quote = match[0]
            str_literal = quote + match[1] + quote

            if str_literal in cur_replaced_strs:
                continue

            # FIXME: substitute the ' % s ' with
            if str_literal in ['\'%s\'', '\"%s\"']:
                continue

            str_repr = '_STR:%d_' % str_count
            str_map[str_literal] = str_repr

            query = query.replace(str_literal, str_repr)

            str_count += 1
            cur_replaced_strs.add(str_literal)

        # tokenize
        query_tokens = nltk.word_tokenize(query)

        new_query_tokens = []
        # break up dotted names like foo.bar.func: after the original token,
        # append '[', the individual components separated by '.', and ']'
        for token in query_tokens:
            new_query_tokens.append(token)
            i = token.find('.')
            if 0 < i < len(token) - 1:
                new_tokens = ['['] + token.replace('.', ' . ').split(' ') + [']']
                new_query_tokens.extend(new_tokens)

        query = ' '.join(new_query_tokens)

        return query, str_map

    @staticmethod
    def canonicalize_example(query, code):
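        """
        Canonicalize a (query, code) pair: replace string literals in the query
        with placeholders, substitute the same placeholders into the code, and
        verify that the resulting code still parses.
        """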
        canonical_query, str_map = Django.canonicalize_query(query)
        query_tokens = canonical_query.split(' ')
        canonical_code = code

        for str_literal, str_repr in str_map.iteritems():
            canonical_code = canonical_code.replace(str_literal, '\'' + str_repr + '\'')

        canonical_code = Django.canonicalize_code(canonical_code)

        # sanity check
        try:
            gold_ast_tree = ast.parse(canonical_code).body[0]
        except:
            print('warning: failed to parse the canonicalized code, falling back to the original code', file=sys.stderr)
            canonical_code = Django.canonicalize_code(code)
            gold_ast_tree = ast.parse(canonical_code).body[0]
            str_map = {}

        # parse_tree = python_ast_to_asdl_ast(gold_ast_tree, grammar)
        # gold_source = astor.to_source(gold_ast_tree)
        # ast_tree = asdl_ast_to_python_ast(parse_tree, grammar)
        # source = astor.to_source(ast_tree)

        # assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, source)
        #
        # # action check
        # parser = PythonTransitionSystem(grammar)
        # actions = parser.get_actions(parse_tree)
        #
        # hyp = Hypothesis()
        # for action in actions:
        #     assert action.__class__ in parser.get_valid_continuation_types(hyp)
        #     if isinstance(action, ApplyRuleAction):
        #         assert action in parser.get_valid_continuations(hyp)
        #     hyp.apply_action(action)
        #
        # src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
        # assert src_from_hyp == gold_source

        return query_tokens, canonical_code, str_map

    @staticmethod
    def parse_django_dataset(annot_file, code_file, asdl_file_path, MAX_QUERY_LENGTH=70):
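        """
        Load parallel annotation/code files, convert every target program into
        an ASDL AST and its action sequence, build the source and primitive
        vocabularies, and split the examples into train (first 16000),
        dev (next 1000), and test (the rest).
        """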
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check: the action sequence must reconstruct the gold source
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp.apply_action(action)

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source

            loaded_examples.append({'src_query_tokens': src_query_tokens,
                                    'tgt_canonical_code': tgt_canonical_code,
                                    'tgt_ast': tgt_ast,
                                    'tgt_actions': tgt_actions,
                                    'raw_code': tgt_code, 'str_map': str_map})

            print('first pass, processed %d' % idx, file=sys.stderr)

        src_vocab = VocabEntry.from_corpus([e['src_query_tokens'] for e in loaded_examples], size=5000, freq_cutoff=3)

        primitive_tokens = [map(lambda a: a.token,
                                filter(lambda a: isinstance(a, GenTokenAction), e['tgt_actions']))
                            for e in loaded_examples]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=3)
        assert '_STR:0_' in primitive_vocab

        vocab = Vocab(source=src_vocab, primitive=primitive_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:MAX_QUERY_LENGTH]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = Django.get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={'raw_code': e['raw_code'], 'str_map': e['str_map']})

            print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, dev, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Examples with more than 100 actions: %d' % len(filter(lambda x: x > 100, action_len)), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab

    @staticmethod
    def get_action_infos(src_query, tgt_actions):
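        """
        Wrap each target action in an ActionInfo that records its timestep,
        the creation time, production, and field of the current frontier node,
        and, for GenTokenAction, whether the token can be copied from the
        source query and from which position.
        """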
        action_infos = []
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            action_info = ActionInfo(action)
            action_info.t = t
            if hyp.frontier_node:
                action_info.parent_t = hyp.frontier_node.created_time
                action_info.frontier_prod = hyp.frontier_node.production
                action_info.frontier_field = hyp.frontier_field.field

            if isinstance(action, GenTokenAction):
                try:
                    # record the first position where the token can be copied from the source query
                    tok_src_idx = src_query.index(str(action.token))
                    action_info.copy_from_src = True
                    action_info.src_token_position = tok_src_idx
                except ValueError:
                    pass

            hyp.apply_action(action)
            action_infos.append(action_info)

        return action_infos

    @staticmethod
    def generate_django_dataset():
        annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
        code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

        (train, dev, test), vocab = Django.parse_django_dataset(annot_file, code_file, 'asdl/lang/py/py_asdl.txt')

        pickle.dump(train, open('data/django/train.bin', 'w'))
        pickle.dump(dev, open('data/django/dev.bin', 'w'))
        pickle.dump(test, open('data/django/test.bin', 'w'))
        pickle.dump(vocab, open('data/django/vocab.bin', 'w'))

    @staticmethod
    def run():
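        """
        Sanity-check the full dataset: every example must round-trip from
        canonical code to an ASDL AST and action sequence and back to the gold
        source produced by astor.
        """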
        asdl_text = open('asdl/lang/py/py_asdl.txt').read()
        grammar = ASDLGrammar.from_text(asdl_text)

        annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
        code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

        transition_system = PythonTransitionSystem(grammar)

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp.apply_action(action)

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source


if __name__ == '__main__':
    Django.run()
    # f1 = Field('hahah', ASDLPrimitiveType('123'), 'single')
    # rf1 = RealizedField(f1, value=123)
    #
    # # print(f1 == rf1)
    # a = {f1: 1}
    # print(a[rf1])
    # Django.generate_django_dataset()