Commit cfecde9

Committed Jan 20, 2018
init
0 parents · commit cfecde9

25 files changed: +2847 −0 lines changed
 

__init__.py

+1
@@ -0,0 +1 @@
# coding=utf-8

asdl/__init__.py

Whitespace-only changes.

asdl/asdl.py

+315
@@ -0,0 +1,315 @@
# coding=utf-8
from itertools import chain

import utils


class ASDLGrammar(object):
    """
    Collection of types, constructors and productions
    """
    def __init__(self, productions):
        # productions are indexed by their head types
        self._productions = dict()
        self._constructor_production_map = dict()
        for prod in productions:
            if prod.type not in self._productions:
                self._productions[prod.type] = list()
            self._productions[prod.type].append(prod)
            self._constructor_production_map[prod.constructor.name] = prod

        self.root_type = productions[0].type
        # number of constructors
        self.size = sum(len(head) for head in self._productions.itervalues())

        # get entities to their ids map
        self.prod2id = {prod: i for i, prod in enumerate(self.productions)}
        self.type2id = {type: i for i, type in enumerate(self.types)}
        self.field2id = {field: i for i, field in enumerate(self.fields)}

        self.id2prod = {i: prod for i, prod in enumerate(self.productions)}
        self.id2type = {i: type for i, type in enumerate(self.types)}
        self.id2field = {i: field for i, field in enumerate(self.fields)}

    def __len__(self):
        return self.size

    @property
    def productions(self):
        return sorted(chain.from_iterable(self._productions.itervalues()), key=lambda x: repr(x))

    def __getitem__(self, datum):
        if isinstance(datum, str):
            return self._productions[ASDLType(datum)]
        elif isinstance(datum, ASDLType):
            return self._productions[datum]

    def get_prod_by_ctr_name(self, name):
        return self._constructor_production_map[name]

    @property
    def types(self):
        if not hasattr(self, '_types'):
            all_types = set()
            for prod in self.productions:
                all_types.add(prod.type)
                all_types.update(map(lambda x: x.type, prod.constructor.fields))

            self._types = sorted(all_types, key=lambda x: x.name)

        return self._types

    @property
    def fields(self):
        if not hasattr(self, '_fields'):
            all_fields = set()
            for prod in self.productions:
                all_fields.update(prod.constructor.fields)

            self._fields = sorted(all_fields, key=lambda x: x.name)

        return self._fields

    @property
    def primitive_types(self):
        return filter(lambda x: isinstance(x, ASDLPrimitiveType), self.types)

    @property
    def composite_types(self):
        return filter(lambda x: isinstance(x, ASDLCompositeType), self.types)

    def is_composite_type(self, asdl_type):
        return asdl_type in self.composite_types

    def is_primitive_type(self, asdl_type):
        return asdl_type in self.primitive_types

    @staticmethod
    def from_text(text):
        def _parse_field_from_text(_text):
            d = _text.strip().split(' ')
            name = d[1].strip()
            type_str = d[0].strip()
            cardinality = 'single'
            if type_str[-1] == '*':
                type_str = type_str[:-1]
                cardinality = 'multiple'
            elif type_str[-1] == '?':
                type_str = type_str[:-1]
                cardinality = 'optional'

            if type_str in primitive_type_names:
                return Field(name, ASDLPrimitiveType(type_str), cardinality=cardinality)
            else:
                return Field(name, ASDLCompositeType(type_str), cardinality=cardinality)

        def _parse_constructor_from_text(_text):
            _text = _text.strip()
            fields = None
            if '(' in _text:
                name = _text[:_text.find('(')]
                field_blocks = _text[_text.find('(') + 1:_text.find(')')].split(',')
                fields = map(_parse_field_from_text, field_blocks)
            else:
                name = _text

            if name == '': name = None

            return ASDLConstructor(name, fields)

        lines = utils.remove_comment(text).split('\n')
        lines = map(lambda l: l.strip(), lines)
        lines = filter(lambda l: l, lines)
        line_no = 0

        # first line is always the primitive types
        primitive_type_names = map(lambda x: x.strip(), lines[line_no].split(','))
        line_no += 1

        all_productions = list()

        while True:
            type_block = lines[line_no]
            type_name = type_block[:type_block.find('=')].strip()
            constructors_blocks = type_block[type_block.find('=') + 1:].split('|')
            i = line_no + 1
            while i < len(lines) and lines[i].strip().startswith('|'):
                t = lines[i].strip()
                cont_constructors_blocks = t[1:].split('|')
                constructors_blocks.extend(cont_constructors_blocks)

                i += 1

            constructors_blocks = filter(lambda x: x and x.strip(), constructors_blocks)

            # parse type name
            new_type = ASDLPrimitiveType(type_name) if type_name in primitive_type_names else ASDLCompositeType(type_name)
            constructors = map(_parse_constructor_from_text, constructors_blocks)

            productions = map(lambda c: ASDLProduction(new_type, c), constructors)
            all_productions.extend(productions)

            line_no = i
            if line_no == len(lines):
                break

        grammar = ASDLGrammar(all_productions)
        grammar.primitive_types

        return grammar


class ASDLProduction(object):
    def __init__(self, type, constructor):
        self.type = type
        self.constructor = constructor

    @property
    def fields(self):
        return self.constructor.fields

    def __getitem__(self, field_name):
        return self.constructor[field_name]

    def __hash__(self):
        h = hash(self.type) ^ hash(self.constructor)

        return h

    def __eq__(self, other):
        return isinstance(other, ASDLProduction) and \
               self.type == other.type and \
               self.constructor == other.constructor

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return '%s -> %s' % (self.type.__repr__(plain=True), self.constructor.__repr__(plain=True))


class ASDLConstructor(object):
    def __init__(self, name, fields=None):
        self.name = name
        self.fields = []
        if fields:
            self.fields = list(fields)

    def __getitem__(self, field_name):
        for field in self.fields:
            if field.name == field_name: return field

        raise KeyError

    def __hash__(self):
        h = hash(self.name)
        for field in self.fields:
            h ^= hash(field)

        return h

    def __eq__(self, other):
        return isinstance(other, ASDLConstructor) and \
               self.name == other.name and \
               self.fields == other.fields

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = '%s(%s)' % (self.name,
                                 ', '.join(f.__repr__(plain=True) for f in self.fields))
        if plain: return plain_repr
        else: return 'Constructor(%s)' % plain_repr


class Field(object):
    def __init__(self, name, type, cardinality):
        self.name = name
        self.type = type

        assert cardinality in ['single', 'optional', 'multiple']
        self.cardinality = cardinality

    def __hash__(self):
        h = hash(self.name) ^ hash(self.type)
        h ^= hash(self.cardinality)

        return h

    def __eq__(self, other):
        return isinstance(other, Field) and \
               self.name == other.name and \
               self.type == other.type and \
               self.cardinality == other.cardinality

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = '%s%s %s' % (self.type.__repr__(plain=True),
                                  Field.get_cardinality_repr(self.cardinality),
                                  self.name)
        if plain: return plain_repr
        else: return 'Field(%s)' % plain_repr

    @staticmethod
    def get_cardinality_repr(cardinality):
        return '' if cardinality == 'single' else '?' if cardinality == 'optional' else '*'


class ASDLType(object):
    def __init__(self, type_name):
        self.name = type_name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, ASDLType) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = self.name
        if plain: return plain_repr
        else: return '%s(%s)' % (self.__class__.__name__, plain_repr)


class ASDLCompositeType(ASDLType):
    pass


class ASDLPrimitiveType(ASDLType):
    pass


if __name__ == '__main__':
    asdl_desc = """
    var, ent, num, var_type

    expr = Variable(var variable)
        | Entity(ent entity)
        | Number(num number)
        | Apply(pred predicate, expr* arguments)
        | Argmax(var variable, expr domain, expr body)
        | Argmin(var variable, expr domain, expr body)
        | Count(var variable, expr body)
        | Exists(var variable, expr body)
        | Lambda(var variable, var_type type, expr body)
        | Max(var variable, expr body)
        | Min(var variable, expr body)
        | Sum(var variable, expr domain, expr body)
        | The(var variable, expr body)
        | Not(expr argument)
        | And(expr* arguments)
        | Or(expr* arguments)
        | Compare(cmp_op op, expr left, expr right)

    cmp_op = GreaterThan | Equal | LessThan
    """

    grammar = ASDLGrammar.from_text(asdl_desc)
    print(ASDLCompositeType('1') == ASDLPrimitiveType('1'))
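
Editor's note: a minimal usage sketch (not part of this commit) of how the grammar reader above is meant to be driven; the toy grammar text is made up for illustration.

# hypothetical usage sketch (Python 2, matching the module above)
toy_desc = """
var, ent

expr = Variable(var variable)
    | Entity(ent entity)
"""
grammar = ASDLGrammar.from_text(toy_desc)
print(len(grammar))        # 2: one constructor per production
print(grammar['expr'])     # productions whose head type is `expr`
print(grammar.prod2id)     # stable production -> id mapping used by the neural model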

asdl/asdl_ast.py

+180
@@ -0,0 +1,180 @@
# coding=utf-8

from cStringIO import StringIO

from asdl import *
from asdl import Field


class AbstractSyntaxTree(object):
    def __init__(self, production, realized_fields=None):
        self.production = production

        # a child is essentially a *realized_field*
        self.fields = []

        # record its parent field to which it's attached
        self.parent_field = None

        # used in decoding, record the time step when this node was created
        self.created_time = 0

        if realized_fields:
            assert len(realized_fields) == len(self.production.fields)

            for field in realized_fields:
                self.add_child(field)
        else:
            for field in self.production.fields:
                self.add_child(RealizedField(field))

    def add_child(self, realized_field):
        # if isinstance(realized_field.value, AbstractSyntaxTree):
        #     realized_field.value.parent = self
        self.fields.append(realized_field)
        realized_field.parent_node = self

    def sanity_check(self):
        if len(self.production.fields) != len(self.fields):
            raise ValueError('field number must match')
        for child in self.fields:
            if isinstance(child.value, AbstractSyntaxTree):
                child.value.sanity_check()

    def copy(self):
        new_tree = AbstractSyntaxTree(self.production)
        new_tree.created_time = self.created_time
        for i, old_field in enumerate(self.fields):
            new_field = new_tree.fields[i]
            if isinstance(old_field.type, ASDLCompositeType):
                for value in old_field.as_value_list:
                    new_field.add_value(value.copy())
            else:
                for value in old_field.as_value_list:
                    new_field.add_value(value)

        return new_tree

    def to_string(self, sb=None):
        is_root = False
        if sb is None:
            is_root = True
            sb = StringIO()

        sb.write('(')
        sb.write(self.production.constructor.name)

        for field in self.fields:
            sb.write(' ')
            sb.write('(')
            sb.write(field.type.name)
            sb.write(Field.get_cardinality_repr(field.cardinality))
            sb.write('-')
            sb.write(field.name)

            if field.value is not None:
                for val_node in field.as_value_list:
                    sb.write(' ')
                    if isinstance(field.type, ASDLCompositeType):
                        val_node.to_string(sb)
                    else:
                        sb.write(str(val_node).replace(' ', '-SPACE-'))

            sb.write(')')  # of field

        sb.write(')')  # of node

        if is_root:
            return sb.getvalue()

    def __hash__(self):
        code = hash(self.production)
        for field in self.fields:
            code = code + 37 * hash(field)

        return code

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        if self.production != other.production:
            return False

        if len(self.fields) != len(other.fields):
            return False

        for i in xrange(len(self.fields)):
            if self.fields[i] != other.fields[i]: return False

        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return repr(self.production)


class RealizedField(Field):
    """wrapper of field realized with values"""
    def __init__(self, field, value=None, parent=None):
        super(RealizedField, self).__init__(field.name, field.type, field.cardinality)

        # record its parent AST node
        self.parent_node = None

        # FIXME: hack, return the field as a property
        self.field = field

        # initialize value to correct type
        if self.cardinality == 'multiple':
            self.value = []
            if value:
                for child_node in value:
                    self.add_value(child_node)
        else:
            self.value = None
            if value: self.add_value(value)

        # properties only used in decoding, record if the field is finished generating
        # when card in [optional, multiple]
        self._not_single_cardinality_finished = False

    def add_value(self, value):
        if isinstance(value, AbstractSyntaxTree):
            value.parent_field = self

        if self.cardinality == 'multiple':
            self.value.append(value)
        else:
            self.value = value

    @property
    def as_value_list(self):
        """get value as an iterable"""
        if self.cardinality == 'multiple': return self.value
        elif self.value is not None: return [self.value]
        else: return []

    @property
    def finished(self):
        if self.cardinality == 'single':
            if self.value is None: return False
            else: return True
        elif self.cardinality == 'optional' and self.value is not None:
            return True
        else:
            if self._not_single_cardinality_finished: return True
            else: return False

    def set_finish(self):
        # assert self.cardinality in ('optional', 'multiple')
        self._not_single_cardinality_finished = True

    def __eq__(self, other):
        if super(RealizedField, self).__eq__(other):
            if type(other) == Field: return True  # FIXME: hack, Field and RealizedField can compare!
            if self.value == other.value: return True
            else: return False
        else: return False
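
Editor's note: a small sketch (not in the commit) of how AbstractSyntaxTree and RealizedField compose, assuming the toy grammar from the asdl.py example above:

# hypothetical: realize `expr -> Variable(var variable)` by hand
prod = grammar.get_prod_by_ctr_name('Variable')
node = AbstractSyntaxTree(prod, realized_fields=[RealizedField(prod['variable'], '$0')])
node.sanity_check()
print(node.to_string())   # -> (Variable (var-variable $0))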

asdl/hypothesis.py

+100
@@ -0,0 +1,100 @@
# coding=utf-8

from asdl import *
from asdl_ast import AbstractSyntaxTree
from transition_system import *


class Hypothesis(object):
    def __init__(self):
        self.tree = None
        self.actions = []
        self.score = 0.
        self.frontier_node = None
        self.frontier_field = None
        self._value_buffer = []

        # record the current time step
        self.t = 0

    def apply_action(self, action):
        if self.tree is None:
            assert isinstance(action, ApplyRuleAction), 'Invalid action [%s], only ApplyRule action is valid ' \
                                                        'at the beginning of decoding' % action

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                if isinstance(action, ApplyRuleAction):
                    field_value = AbstractSyntaxTree(action.production)
                    field_value.created_time = self.t
                    self.frontier_field.add_value(field_value)
                    self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \
                                                                                        'applied on a field with optional ' \
                                                                                        'or multiple cardinality'
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field))
            else:  # fill in a primitive field
                if isinstance(action, GenTokenAction):
                    # only field of type string requires termination signal </primitive>
                    end_primitive = False
                    if self.frontier_field.type.name == 'string':
                        if action.is_stop_signal():
                            self.frontier_field.add_value(' '.join(self._value_buffer))
                            self._value_buffer = []

                            end_primitive = True
                        else:
                            self._value_buffer.append(action.token)
                    else:
                        self.frontier_field.add_value(action.token)
                        end_primitive = True

                    if end_primitive and self.frontier_field.cardinality in ('single', 'optional'):
                        self.frontier_field.set_finish()
                        self.update_frontier_info()

                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \
                                                                                        'applied on a field with optional ' \
                                                                                        'or multiple cardinality'
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields')

        self.t += 1
        self.actions.append(action)

    def update_frontier_info(self):
        def _find_frontier_node_and_field(tree_node):
            if tree_node:
                for field in tree_node.fields:
                    # if it's an intermediate node, check its children
                    if isinstance(field.type, ASDLCompositeType) and field.value:
                        if field.cardinality in ('single', 'optional'): iter_values = [field.value]
                        else: iter_values = field.value

                        for child_node in iter_values:
                            result = _find_frontier_node_and_field(child_node)
                            if result: return result

                    # now all its possible children are checked
                    if not field.finished:
                        return tree_node, field

                return None
            else: return None

        # reset the frontier first; a completed tree has no frontier field
        self.frontier_node, self.frontier_field = None, None
        frontier_info = _find_frontier_node_and_field(self.tree)
        if frontier_info:
            self.frontier_node, self.frontier_field = frontier_info

    @property
    def completed(self):
        return self.tree and self.frontier_field is None
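
Editor's note: a sketch of the intended decoding loop (not part of the commit), assuming a transition_system and a gold asdl_ast as in asdl/lang/py/dataset.py below:

# hypothetical: replay gold actions and check the hypothesis completes
hyp = Hypothesis()
for action in transition_system.get_actions(asdl_ast):
    # every replayed action must be of a currently valid continuation type
    assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
    hyp.apply_action(action)
assert hyp.completed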

asdl/lang/__init__.py

Whitespace-only changes.

asdl/lang/py/__init__.py

Whitespace-only changes.

asdl/lang/py/dataset.py

+313
@@ -0,0 +1,313 @@
# coding=utf-8

from __future__ import print_function

import re
import cPickle as pickle
import ast
import astor
import nltk
import sys

import numpy as np

from asdl.asdl_ast import RealizedField
from asdl.lang.py.py_asdl_helper import python_ast_to_asdl_ast, asdl_ast_to_python_ast
from asdl.lang.py.py_transition_system import PythonTransitionSystem
from asdl.hypothesis import *

from components.action_info import ActionInfo

p_elif = re.compile(r'^elif\s?')
p_else = re.compile(r'^else\s?')
p_try = re.compile(r'^try\s?')
p_except = re.compile(r'^except\s?')
p_finally = re.compile(r'^finally\s?')
p_decorator = re.compile(r'^@.*')

QUOTED_STRING_RE = re.compile(r"(?P<quote>['\"])(?P<string>.*?)(?<!\\)(?P=quote)")


class Django(object):
    @staticmethod
    def canonicalize_code(code):
        if p_elif.match(code):
            code = 'if True: pass\n' + code

        if p_else.match(code):
            code = 'if True: pass\n' + code

        if p_try.match(code):
            code = code + 'pass\nexcept: pass'
        elif p_except.match(code):
            code = 'try: pass\n' + code
        elif p_finally.match(code):
            code = 'try: pass\n' + code

        if p_decorator.match(code):
            code = code + '\ndef dummy(): pass'

        if code[-1] == ':':
            code = code + 'pass'

        return code

    @staticmethod
    def canonicalize_query(query):
        """
        canonicalize the query, replacing string literals with special placeholders
        """
        str_count = 0
        str_map = dict()

        matches = QUOTED_STRING_RE.findall(query)
        # de-duplicate
        cur_replaced_strs = set()
        for match in matches:
            # If one or more groups are present in the pattern,
            # it returns a list of groups
            quote = match[0]
            str_literal = quote + match[1] + quote

            if str_literal in cur_replaced_strs:
                continue

            # FIXME: substitute the ' % s ' with
            if str_literal in ['\'%s\'', '\"%s\"']:
                continue

            str_repr = '_STR:%d_' % str_count
            str_map[str_literal] = str_repr

            query = query.replace(str_literal, str_repr)

            str_count += 1
            cur_replaced_strs.add(str_literal)

        # tokenize
        query_tokens = nltk.word_tokenize(query)

        new_query_tokens = []
        # break up function calls like foo.bar.func
        for token in query_tokens:
            new_query_tokens.append(token)
            i = token.find('.')
            if 0 < i < len(token) - 1:
                new_tokens = ['['] + token.replace('.', ' . ').split(' ') + [']']
                new_query_tokens.extend(new_tokens)

        query = ' '.join(new_query_tokens)

        return query, str_map

    @staticmethod
    def canonicalize_example(query, code):
        canonical_query, str_map = Django.canonicalize_query(query)
        query_tokens = canonical_query.split(' ')
        canonical_code = code

        for str_literal, str_repr in str_map.iteritems():
            canonical_code = canonical_code.replace(str_literal, '\'' + str_repr + '\'')

        canonical_code = Django.canonicalize_code(canonical_code)

        # sanity check
        try:
            gold_ast_tree = ast.parse(canonical_code).body[0]
        except:
            print('error!')
            canonical_code = Django.canonicalize_code(code)
            gold_ast_tree = ast.parse(canonical_code).body[0]
            str_map = {}

        # parse_tree = python_ast_to_asdl_ast(gold_ast_tree, grammar)
        # gold_source = astor.to_source(gold_ast_tree)
        # ast_tree = asdl_ast_to_python_ast(parse_tree, grammar)
        # source = astor.to_source(ast_tree)

        # assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, source)
        #
        # # action check
        # parser = PythonTransitionSystem(grammar)
        # actions = parser.get_actions(parse_tree)
        #
        # hyp = Hypothesis()
        # for action in actions:
        #     assert action.__class__ in parser.get_valid_continuation_types(hyp)
        #     if isinstance(action, ApplyRuleAction):
        #         assert action in parser.get_valid_continuations(hyp)
        #     hyp.apply_action(action)
        #
        # src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
        # assert src_from_hyp == gold_source

        return query_tokens, canonical_code, str_map

    @staticmethod
    def parse_django_dataset(annot_file, code_file, asdl_file_path, MAX_QUERY_LENGTH=70):
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp.apply_action(action)

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source

            loaded_examples.append({'src_query_tokens': src_query_tokens,
                                    'tgt_canonical_code': tgt_canonical_code,
                                    'tgt_ast': tgt_ast,
                                    'tgt_actions': tgt_actions,
                                    'raw_code': tgt_code, 'str_map': str_map})

            print('first pass, processed %d' % idx, file=sys.stderr)

        src_vocab = VocabEntry.from_corpus([e['src_query_tokens'] for e in loaded_examples], size=5000, freq_cutoff=3)

        primitive_tokens = [map(lambda a: a.token,
                                filter(lambda a: isinstance(a, GenTokenAction), e['tgt_actions']))
                            for e in loaded_examples]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=3)
        assert '_STR:0_' in primitive_vocab

        vocab = Vocab(source=src_vocab, primitive=primitive_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:MAX_QUERY_LENGTH]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = Django.get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={'raw_code': e['raw_code'], 'str_map': e['str_map']})

            print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' % len(filter(lambda x: x > 100, action_len)), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab

    @staticmethod
    def get_action_infos(src_query, tgt_actions):
        action_infos = []
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            action_info = ActionInfo(action)
            action_info.t = t
            if hyp.frontier_node:
                action_info.parent_t = hyp.frontier_node.created_time
                action_info.frontier_prod = hyp.frontier_node.production
                action_info.frontier_field = hyp.frontier_field.field

            if isinstance(action, GenTokenAction):
                try:
                    tok_src_idx = src_query.index(str(action.token))
                    action_info.copy_from_src = True
                    action_info.src_token_position = tok_src_idx
                except ValueError:
                    pass

            hyp.apply_action(action)
            action_infos.append(action_info)

        return action_infos

    @staticmethod
    def generate_django_dataset():
        annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
        code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

        (train, dev, test), vocab = Django.parse_django_dataset(annot_file, code_file, 'asdl/lang/py/py_asdl.txt')

        pickle.dump(train, open('data/django/train.bin', 'w'))
        pickle.dump(dev, open('data/django/dev.bin', 'w'))
        pickle.dump(test, open('data/django/test.bin', 'w'))
        pickle.dump(vocab, open('data/django/vocab.bin', 'w'))

    @staticmethod
    def run():
        asdl_text = open('asdl/lang/py/py_asdl.txt').read()
        grammar = ASDLGrammar.from_text(asdl_text)

        annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
        code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

        transition_system = PythonTransitionSystem(grammar)

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp.apply_action(action)

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source


if __name__ == '__main__':
    Django.run()
    # f1 = Field('hahah', ASDLPrimitiveType('123'), 'single')
    # rf1 = RealizedField(f1, value=123)
    #
    # # print(f1 == rf1)
    # a = {f1: 1}
    # print(a[rf1])
    # Django.generate_django_dataset()
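
Editor's note: for intuition, a sketch (not in the commit) of what Django.canonicalize_query does to a string-bearing query:

# hypothetical input/output
query = "call the function 'render' on the response"
canonical_query, str_map = Django.canonicalize_query(query)
# the quoted literal is replaced by an indexed placeholder before tokenization:
#   str_map == {"'render'": '_STR:0_'}
# and dotted names like foo.bar are additionally expanded as [ foo . bar ]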

asdl/lang/py/py_asdl_helper.py

+99
@@ -0,0 +1,99 @@
# coding=utf-8

import sys

from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


# from https://stackoverflow.com/questions/15357422/python-determine-if-a-string-should-be-converted-into-int-or-float
def isfloat(x):
    try:
        a = float(x)
    except ValueError:
        return False
    else:
        return True


def isint(x):
    try:
        a = float(x)
        b = int(a)
    except ValueError:
        return False
    else:
        return a == b


def python_ast_to_asdl_ast(py_ast_node, grammar):
    # node should be composite
    py_node_name = type(py_ast_node).__name__
    # assert py_node_name.startswith('_ast.')

    production = grammar.get_prod_by_ctr_name(py_node_name)

    fields = []
    for field in production.fields:
        field_value = getattr(py_ast_node, field.name)
        asdl_field = RealizedField(field)
        if field.cardinality == 'single' or field.cardinality == 'optional':
            if field_value is not None:  # sometimes it could be 0
                if grammar.is_composite_type(field.type):
                    asdl_field.value = python_ast_to_asdl_ast(field_value, grammar)
                else:
                    asdl_field.value = field_value
        else:
            if field_value is not None:
                vals = []
                if grammar.is_composite_type(field.type):
                    for val in field_value:
                        child_node = python_ast_to_asdl_ast(val, grammar)
                        vals.append(child_node)

                    asdl_field.value = vals
                else:
                    asdl_field.value = str(field_value)

        fields.append(asdl_field)

    asdl_node = AbstractSyntaxTree(production, realized_fields=fields)

    return asdl_node


def asdl_ast_to_python_ast(asdl_ast_node, grammar):
    py_node_type = getattr(sys.modules['ast'], asdl_ast_node.production.constructor.name)
    py_ast_node = py_node_type()

    for field in asdl_ast_node.fields:
        # for composite node
        field_value = None
        if grammar.is_composite_type(field.type):
            if field.value and field.cardinality == 'multiple':
                field_value = []
                for val in field.value:
                    node = asdl_ast_to_python_ast(val, grammar)
                    field_value.append(node)
            elif field.value and field.cardinality in ('single', 'optional'):
                field_value = asdl_ast_to_python_ast(field.value, grammar)
        else:
            # for primitive node
            if field.type.name == 'object':
                if isfloat(field.value):
                    field_value = float(field.value)
                elif isint(field.value):
                    field_value = int(field.value)
                else:
                    raise ValueError('cannot convert [%s] to float or int' % field.value)
            elif field.type.name == 'int':
                field_value = int(field.value)
            else:
                field_value = field.value

        # must set unused fields to default value...
        if field_value is None and field.cardinality == 'multiple':
            field_value = list()

        setattr(py_ast_node, field.name, field_value)

    return py_ast_node
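
Editor's note: a minimal round-trip sketch (not part of the commit), assuming py_asdl.txt holds the Python ASDL grammar used throughout this commit:

# hypothetical: Python AST -> ASDL AST -> Python AST round trip
import ast
import astor
from asdl.asdl import ASDLGrammar
grammar = ASDLGrammar.from_text(open('asdl/lang/py/py_asdl.txt').read())
py_ast = ast.parse('x = max(a, b)').body[0]
asdl_ast = python_ast_to_asdl_ast(py_ast, grammar)
py_ast_back = asdl_ast_to_python_ast(asdl_ast, grammar)
# the reconstructed tree should pretty-print to the same source
assert astor.to_source(py_ast_back) == astor.to_source(py_ast)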

asdl/lang/py/py_grammar.py

+41
@@ -0,0 +1,41 @@
# coding=utf-8

import ast

from asdl.asdl import ASDLGrammar
from asdl.lang.py.py_asdl_helper import *
from asdl.lang.py.py_transition_system import *

if __name__ == '__main__':
    asdl_text = open('py_asdl.txt').read()
    grammar = ASDLGrammar.from_text(asdl_text)
    py_code = 'sorted(mydict, key=mydict.get, reverse=True, how="hahaha", sadf=0.3)'
    # py_code = 'a = dict({a: None, b:False, s:"I love my mother", sd:124+3})'
    # py_code = '1e10'
    py_ast = ast.parse(py_code)
    asdl_ast = python_ast_to_asdl_ast(py_ast.body[0], grammar)
    py_ast_reconstructed = asdl_ast_to_python_ast(asdl_ast, grammar)

    asdl_ast2 = asdl_ast.copy()
    assert asdl_ast == asdl_ast2
    del asdl_ast2

    parser = PythonTransitionSystem(grammar)
    actions = parser.get_actions(asdl_ast)

    from asdl.hypothesis import *
    hyp = Hypothesis()
    for action in actions:
        # assert action.__class__ in parser.get_valid_continuation_types(hyp)
        # if isinstance(action, ApplyRuleAction):
        #     assert action.production in grammar[hyp.frontier_field.type]
        hyp.apply_action(action)

    import astor
    src1 = astor.to_source(py_ast)
    src2 = astor.to_source(py_ast_reconstructed)
    src3 = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))

    print(src3)
    assert src1 == src3
    pass

asdl/lang/py/py_transition_system.py

+26
@@ -0,0 +1,26 @@
# coding=utf-8

from asdl.transition_system import TransitionSystem, GenTokenAction


class PythonTransitionSystem(TransitionSystem):
    def get_primitive_field_actions(self, realized_field):
        actions = []
        if realized_field.value is not None:
            if realized_field.cardinality == 'multiple':  # expr -> Global(identifier* names)
                field_values = realized_field.value
            else:
                field_values = [realized_field.value]

            tokens = []
            if realized_field.type.name == 'string':
                for field_val in field_values:
                    tokens.extend(field_val.split(' ') + ['</primitive>'])
            else:
                for field_val in field_values:
                    tokens.append(field_val)

            for tok in tokens:
                actions.append(GenTokenAction(tok))

        return actions
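
Editor's note: a sketch (not in the commit) of the action sequence this yields for a realized `string` field, e.g. one whose value is 'hello world':

# hypothetical: string fields are emitted token by token, then terminated
actions = PythonTransitionSystem(grammar).get_primitive_field_actions(realized_field)
# -> [GenTokenAction('hello'), GenTokenAction('world'), GenTokenAction('</primitive>')]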

asdl/logical_form.py

+242
@@ -0,0 +1,242 @@
# coding=utf-8

from cStringIO import StringIO
from collections import Iterable

from asdl import *
from asdl_ast import AbstractSyntaxTree, RealizedField


def parse_lambda_expr_helper(s, offset):
    if s[offset] != '(':
        name = ''
        while offset < len(s) and s[offset] != ' ':
            name += s[offset]
            offset += 1

        node = Node(name)
        return node, offset
    else:
        # it's a sub-tree
        offset += 2
        name = ''
        while s[offset] != ' ':
            name += s[offset]
            offset += 1

        node = Node(name)
        # extract its child nodes

        while True:
            if s[offset] != ' ':
                raise ValueError('malformed string: node should have either had a '
                                 'close paren or a space at position %d' % offset)

            offset += 1
            if s[offset] == ')':
                offset += 1
                return node, offset
            else:
                child_node, offset = parse_lambda_expr_helper(s, offset)

                node.add_child(child_node)


def parse_lambda_expr(s):
    return parse_lambda_expr_helper(s, 0)[0]


class Node(object):
    def __init__(self, name, children=None):
        self.name = name
        self.parent = None
        self.children = list()
        if children:
            if isinstance(children, Iterable):
                for child in children:
                    self.add_child(child)
            elif isinstance(children, Node):
                self.add_child(children)
            else: raise ValueError('Wrong type for child nodes')

    def add_child(self, child):
        child.parent = self
        self.children.append(child)

    def __hash__(self):
        code = hash(self.name)

        for child in self.children:
            code = code * 37 + hash(child)

        return code

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        if self.name != other.name:
            return False

        if len(self.children) != len(other.children):
            return False

        if self.name == 'and' or self.name == 'or':
            return sorted(self.children, key=lambda x: x.name) == sorted(other.children, key=lambda x: x.name)
        else:
            return self.children == other.children

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return 'Node[%s, %d children]' % (self.name, len(self.children))

    @property
    def is_leaf(self):
        return len(self.children) == 0

    def to_string(self, sb=None):
        is_root = False
        if sb is None:
            is_root = True
            sb = StringIO()

        if self.is_leaf:
            sb.write(self.name)
        else:
            sb.write('( ')
            sb.write(self.name)

            for child in self.children:
                sb.write(' ')
                child.to_string(sb)

            sb.write(' )')

        if is_root:
            return sb.getvalue()


def logical_form_to_ast(grammar, lf_node):
    if lf_node.name == 'lambda':
        # expr -> Lambda(var variable, var_type type, expr body)
        prod = grammar.get_prod_by_ctr_name('Lambda')

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        var_type_node = lf_node.children[1]
        var_type_field = RealizedField(prod['type'], var_type_node.name)

        body_node = lf_node.children[2]
        body_ast_node = logical_form_to_ast(grammar, body_node)  # of type expr
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [var_field, var_type_field, body_field])
    elif lf_node.name == 'argmax' or lf_node.name == 'argmin':
        # expr -> Argmax(var variable, expr domain, expr body)
        if lf_node.name == 'argmax':
            prod = grammar.get_prod_by_ctr_name('Argmax')
        else:
            prod = grammar.get_prod_by_ctr_name('Argmin')

        var_node = lf_node.children[0]
        var_field = RealizedField(prod['variable'], var_node.name)

        domain_node = lf_node.children[2]
        domain_ast_node = logical_form_to_ast(grammar, domain_node)
        domain_field = RealizedField(prod['domain'], domain_ast_node)

        body_node = lf_node.children[1]
        body_ast_node = logical_form_to_ast(grammar, body_node)
        body_field = RealizedField(prod['body'], body_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [var_field, domain_field, body_field])
    elif lf_node.name == 'and' or lf_node.name == 'or':
        # expr -> And(expr* arguments) | Or(expr* arguments)
        if lf_node.name == 'and':
            prod = grammar.get_prod_by_ctr_name('And')
        else:
            prod = grammar.get_prod_by_ctr_name('Or')

        arg_ast_nodes = []
        for arg_node in lf_node.children:
            arg_ast_node = logical_form_to_ast(grammar, arg_node)
            arg_ast_nodes.append(arg_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [RealizedField(prod['arguments'], arg_ast_nodes)])
    elif lf_node.name == '>' or lf_node.name == '=' or lf_node.name == '<':
        # expr -> Compare(cmp_op op, expr left, expr right)
        prod = grammar.get_prod_by_ctr_name('Compare')
        op_name = 'GreaterThan' if lf_node.name == '>' else 'Equal' if lf_node.name == '=' else 'LessThan'
        op_field = RealizedField(prod['op'], AbstractSyntaxTree(grammar.get_prod_by_ctr_name(op_name)))

        left_node = lf_node.children[0]
        left_ast_node = logical_form_to_ast(grammar, left_node)
        left_field = RealizedField(prod['left'], left_ast_node)

        right_node = lf_node.children[1]
        right_ast_node = logical_form_to_ast(grammar, right_node)
        right_field = RealizedField(prod['right'], right_ast_node)

        ast_node = AbstractSyntaxTree(prod,
                                      [op_field, left_field, right_field])
    elif lf_node.name in ['flight', 'airline', 'from', 'to', 'day', 'month', 'arrival_time',
                          'nonstop', 'has_meal', 'round_trip']:
        # expr -> Apply(pred predicate, expr* arguments)
        prod = grammar.get_prod_by_ctr_name('Apply')
        arg_ast_nodes = []
        for arg_node in lf_node.children:
            arg_ast_node = logical_form_to_ast(grammar, arg_node)
            arg_ast_nodes.append(arg_ast_node)

        # the predicate symbol itself realizes the `pred` field
        ast_node = AbstractSyntaxTree(prod,
                                      [RealizedField(prod['predicate'], lf_node.name),
                                       RealizedField(prod['arguments'], arg_ast_nodes)])
    elif lf_node.name.startswith('$'):
        prod = grammar.get_prod_by_ctr_name('Variable')
        ast_node = AbstractSyntaxTree(prod,
                                      [RealizedField(prod['variable'], lf_node.name)])
    elif ':cl' in lf_node.name or ':pd' in lf_node.name or lf_node.name in ['ci0', 'ci1', 'ti0', 'ti1', 'da0', 'da1', 'al0']:
        prod = grammar.get_prod_by_ctr_name('Entity')
        ast_node = AbstractSyntaxTree(prod,
                                      [RealizedField(prod['entity'], lf_node.name)])
    else:
        raise NotImplementedError

    return ast_node


if __name__ == '__main__':
    asdl_desc = """
    var, ent, num, var_type

    expr = Variable(var variable)
        | Entity(ent entity)
        | Number(num number)
        | Apply(pred predicate, expr* arguments)
        | Argmax(var variable, expr domain, expr body)
        | Argmin(var variable, expr domain, expr body)
        | Count(var variable, expr body)
        | Exists(var variable, expr body)
        | Lambda(var variable, var_type type, expr body)
        | Max(var variable, expr body)
        | Min(var variable, expr body)
        | Sum(var variable, expr domain, expr body)
        | The(var variable, expr body)
        | Not(expr argument)
        | And(expr* arguments)
        | Or(expr* arguments)
        | Compare(cmp_op op, expr left, expr right)

    cmp_op = GreaterThan | Equal | LessThan
    """

    grammar = ASDLGrammar.from_text(asdl_desc)
    # lf = parse_lambda_expr('( lambda $0 e ( and ( flight $0 ) ( airline $0 al0 ) ( from $0 ci0 ) ( to $0 ci1 ) ) )')
    lf = parse_lambda_expr('al0')
    ast_tree = logical_form_to_ast(grammar, lf)
    pass
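
Editor's note: a short sketch (not part of the commit) of the parse-and-convert pipeline defined above, using the ATIS-style grammar from __main__:

# hypothetical: lambda-calculus string -> Node tree -> ASDL AST
lf = parse_lambda_expr('( lambda $0 e ( flight $0 ) )')
print(lf.to_string())      # round-trips the input string
tree = logical_form_to_ast(grammar, lf)
print(tree.to_string())    # s-expression view of the typed AST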

asdl/transition_system.py

+109
@@ -0,0 +1,109 @@
# coding=utf-8


class Action(object):
    pass


class ApplyRuleAction(Action):
    def __init__(self, production):
        self.production = production

    def __hash__(self):
        return hash(self.production)

    def __eq__(self, other):
        return isinstance(other, ApplyRuleAction) and self.production == other.production

    def __ne__(self, other):
        return not self.__eq__(other)


class GenTokenAction(Action):
    def __init__(self, token):
        self.token = token

    def is_stop_signal(self):
        return self.token == '</primitive>'


class ReduceAction(Action):
    pass


class TransitionSystem(object):
    def __init__(self, grammar):
        self.grammar = grammar

    def get_actions(self, asdl_ast):
        """
        generate the action sequence given the ASDL syntax tree
        """

        actions = []

        parent_action = ApplyRuleAction(asdl_ast.production)
        actions.append(parent_action)

        for field in asdl_ast.fields:
            # is a composite field
            if self.grammar.is_composite_type(field.type):
                if field.cardinality == 'single':
                    field_actions = self.get_actions(field.value)
                else:
                    field_actions = []

                    if field.value is not None:
                        if field.cardinality == 'multiple':
                            for val in field.value:
                                cur_child_actions = self.get_actions(val)
                                field_actions.extend(cur_child_actions)
                        elif field.cardinality == 'optional':
                            field_actions = self.get_actions(field.value)

                    # if an optional field is filled, then we do not need a Reduce action
                    if field.cardinality == 'multiple' or (field.cardinality == 'optional' and not field_actions):
                        field_actions.append(ReduceAction())
            else:  # is a primitive field
                field_actions = self.get_primitive_field_actions(field)

                # if an optional field is filled, then we do not need a Reduce action
                if field.cardinality == 'multiple' or (field.cardinality == 'optional' and not field_actions):
                    # reduce action
                    field_actions.append(ReduceAction())

            actions.extend(field_actions)

        return actions

    def get_primitive_field_actions(self, realized_field):
        raise NotImplementedError

    def get_valid_continuation_types(self, hyp):
        if hyp.tree:
            if self.grammar.is_composite_type(hyp.frontier_field.type):
                if hyp.frontier_field.cardinality == 'single':
                    return ApplyRuleAction,
                else:  # optional, multiple
                    return ApplyRuleAction, ReduceAction
            else:
                if hyp.frontier_field.cardinality == 'single':
                    return GenTokenAction,
                elif hyp.frontier_field.cardinality == 'optional':
                    if hyp._value_buffer:
                        return GenTokenAction,
                    else:
                        return GenTokenAction, ReduceAction
                else:
                    return GenTokenAction, ReduceAction
        else:
            return ApplyRuleAction,

    def get_valid_continuating_productions(self, hyp):
        if hyp.tree:
            if self.grammar.is_composite_type(hyp.frontier_field.type):
                return self.grammar[hyp.frontier_field.type]
            else:
                raise ValueError
        else:
            return self.grammar[self.grammar.root_type]
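
Editor's note: a sketch (not in the commit) of the contract get_valid_continuation_types enforces at the start of decoding, assuming Hypothesis and a grammar are in scope:

# hypothetical: before any tree exists, only ApplyRule is allowed
hyp = Hypothesis()
assert TransitionSystem(grammar).get_valid_continuation_types(hyp) == (ApplyRuleAction,)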

asdl/utils.py

+9
@@ -0,0 +1,9 @@
# coding=utf-8
import re


def remove_comment(text):
    text = re.sub(re.compile("#.*"), "", text)
    text = '\n'.join(filter(lambda x: x, text.split('\n')))

    return text
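
Editor's note: a behavior sketch (not in the commit):

# hypothetical: comments and the resulting empty lines are stripped
print(remove_comment('expr = Foo  # a constructor\n\nbar'))
# -> 'expr = Foo  \nbar'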

components/__init__.py

Whitespace-only changes.

components/action_info.py

+16
@@ -0,0 +1,16 @@
# coding=utf-8


class ActionInfo(object):
    """sufficient statistics for making a prediction of an action at a time step"""

    def __init__(self, action):
        self.t = 0
        self.parent_t = -1
        self.action = action
        self.frontier_prod = None
        self.frontier_field = None

        # for GenToken actions only
        self.copy_from_src = False
        self.src_token_position = -1

components/dataset.py

+183
@@ -0,0 +1,183 @@
# coding=utf-8

import torch
import numpy as np
import cPickle as pickle
from torch.autograd import Variable

from asdl.transition_system import ApplyRuleAction, ReduceAction
from components.utils import cached_property

from model import nn_utils


class Dataset(object):
    def __init__(self, examples):
        self.examples = examples

    @property
    def all_source(self):
        return [e.src_sent for e in self.examples]

    @property
    def all_targets(self):
        return [e.tgt_code for e in self.examples]

    @staticmethod
    def from_bin_file(file_path):
        examples = pickle.load(open(file_path, 'rb'))
        return Dataset(examples)

    def batch_iter(self, batch_size, shuffle=False):
        index_arr = np.arange(len(self.examples))
        if shuffle:
            np.random.shuffle(index_arr)

        batch_num = int(np.ceil(len(self.examples) / float(batch_size)))
        for batch_id in xrange(batch_num):
            batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)]
            batch_examples = [self.examples[i] for i in batch_ids]
            batch_examples.sort(key=lambda e: -len(e.src_sent))

            yield batch_examples

    def __len__(self):
        return len(self.examples)


class Example(object):
    def __init__(self, src_sent, tgt_actions, tgt_code, tgt_ast, idx=0, meta=None):
        self.src_sent = src_sent
        self.tgt_code = tgt_code
        self.tgt_ast = tgt_ast
        self.tgt_actions = tgt_actions

        self.idx = idx
        self.meta = meta


class Batch(object):
    def __init__(self, examples, grammar, vocab, cuda=False):
        self.examples = examples
        self.max_action_num = max(len(e.tgt_actions) for e in self.examples)

        self.src_sents = [e.src_sent for e in self.examples]
        self.src_sents_len = [len(e.src_sent) for e in self.examples]

        self.grammar = grammar
        self.vocab = vocab
        self.cuda = cuda

        self.init_index_tensors()

    def __len__(self):
        return len(self.examples)

    def get_frontier_field_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_actions):
                ids.append(self.grammar.field2id[e.tgt_actions[t].frontier_field])
                # assert self.grammar.id2field[ids[-1]] == e.tgt_actions[t].frontier_field
            else:
                ids.append(0)

        return Variable(torch.cuda.LongTensor(ids)) if self.cuda else Variable(torch.LongTensor(ids))

    def get_frontier_prod_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_actions):
                ids.append(self.grammar.prod2id[e.tgt_actions[t].frontier_prod])
                # assert self.grammar.id2prod[ids[-1]] == e.tgt_actions[t].frontier_prod
            else:
                ids.append(0)

        return Variable(torch.cuda.LongTensor(ids)) if self.cuda else Variable(torch.LongTensor(ids))

    def get_frontier_field_type_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_actions):
                ids.append(self.grammar.type2id[e.tgt_actions[t].frontier_field.type])
                # assert self.grammar.id2type[ids[-1]] == e.tgt_actions[t].frontier_field.type
            else:
                ids.append(0)

        return Variable(torch.cuda.LongTensor(ids)) if self.cuda else Variable(torch.LongTensor(ids))

    def init_index_tensors(self):
        self.apply_rule_idx_matrix = []
        self.apply_rule_mask = []
        self.primitive_idx_matrix = []
        self.gen_token_mask = []
        self.primitive_copy_pos_matrix = []
        self.primitive_copy_mask = []

        for t in xrange(self.max_action_num):
            app_rule_idx_row = []
            app_rule_mask_row = []
            token_row = []
            gen_token_mask_row = []
            copy_pos_row = []
            copy_mask_row = []

            for e in self.examples:
                app_rule_idx = app_rule_mask = token_idx = gen_token_mask = copy_pos = copy_mask = 0
                if t < len(e.tgt_actions):
                    action = e.tgt_actions[t].action
                    action_info = e.tgt_actions[t]
                    if isinstance(action, ApplyRuleAction):
                        app_rule_idx = self.grammar.prod2id[action.production]
                        # assert self.grammar.id2prod[app_rule_idx] == action.production
                        app_rule_mask = 1
                    elif isinstance(action, ReduceAction):
                        app_rule_idx = len(self.grammar)
                        app_rule_mask = 1
                    else:
                        token_idx = self.vocab.primitive[action.token]
                        # cannot copy, only generation
                        # could be unk!
                        if not action_info.copy_from_src:
                            gen_token_mask = 1
                        else:  # copy
                            copy_mask = 1
                            copy_pos = action_info.src_token_position
                            if token_idx != self.vocab.primitive.unk_id:
                                # both copy and generate from vocabulary
                                gen_token_mask = 1

                app_rule_idx_row.append(app_rule_idx)
                app_rule_mask_row.append(app_rule_mask)

                token_row.append(token_idx)
                gen_token_mask_row.append(gen_token_mask)
                copy_pos_row.append(copy_pos)
                copy_mask_row.append(copy_mask)

            self.apply_rule_idx_matrix.append(app_rule_idx_row)
            self.apply_rule_mask.append(app_rule_mask_row)

            self.primitive_idx_matrix.append(token_row)
            self.gen_token_mask.append(gen_token_mask_row)

            self.primitive_copy_pos_matrix.append(copy_pos_row)
            self.primitive_copy_mask.append(copy_mask_row)

        T = torch.cuda if self.cuda else torch
        self.apply_rule_idx_matrix = Variable(T.LongTensor(self.apply_rule_idx_matrix))
        self.apply_rule_mask = Variable(T.FloatTensor(self.apply_rule_mask))
        self.primitive_idx_matrix = Variable(T.LongTensor(self.primitive_idx_matrix))
        self.gen_token_mask = Variable(T.FloatTensor(self.gen_token_mask))
        self.primitive_copy_pos_matrix = Variable(T.LongTensor(self.primitive_copy_pos_matrix))
        self.primitive_copy_mask = Variable(T.FloatTensor(self.primitive_copy_mask))

    @cached_property
    def src_sents_var(self):
        return nn_utils.to_input_variable(self.src_sents, self.vocab.source,
                                          cuda=self.cuda)

    @cached_property
    def src_token_mask(self):
        return nn_utils.length_array_to_mask_tensor(self.src_sents_len,
                                                    cuda=self.cuda)
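
Editor's note: a usage sketch (not part of the commit) of the batching pipeline, assuming a grammar and vocab as produced by asdl/lang/py/dataset.py:

# hypothetical training-loop skeleton
train_set = Dataset.from_bin_file('data/django/train.bin')
for batch_examples in train_set.batch_iter(batch_size=10, shuffle=True):
    batch = Batch(batch_examples, grammar, vocab, cuda=False)
    # batch.apply_rule_idx_matrix etc. are (max_action_num x batch_size) index tensors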

components/utils.py

+20
@@ -0,0 +1,20 @@
# coding=utf-8


class cached_property(object):
    """ A property that is only computed once per instance and then replaces
        itself with an ordinary attribute. Deleting the attribute resets the
        property.

        Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
        """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value
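
Editor's note: usage sketch (not in the commit); the first access computes the value, later accesses hit the stored instance attribute:

# hypothetical
class Circle(object):
    def __init__(self, r):
        self.r = r

    @cached_property
    def area(self):
        return 3.14159 * self.r * self.r

c = Circle(2.)
print(c.area)   # computed once
print(c.area)   # served from c.__dict__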

components/vocab.py

+82
@@ -0,0 +1,82 @@
# coding=utf-8

from __future__ import print_function
import argparse
from collections import Counter
from itertools import chain
import torch


class VocabEntry(object):
    def __init__(self):
        self.word2id = dict()
        self.unk_id = 3
        self.word2id['<pad>'] = 0
        self.word2id['<s>'] = 1
        self.word2id['</s>'] = 2
        self.word2id['<unk>'] = 3

        self.id2word = {v: k for k, v in self.word2id.iteritems()}

    def __getitem__(self, word):
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        return word in self.word2id

    def __setitem__(self, key, value):
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        return len(self.word2id)

    def __repr__(self):
        return 'Vocabulary[size=%d]' % len(self)

    def id2word(self, wid):
        return self.id2word[wid]

    def add(self, word):
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]

    def is_unk(self, word):
        return word not in self

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=0):
        vocab_entry = VocabEntry()

        word_freq = Counter(chain(*corpus))
        non_singletons = [w for w in word_freq if word_freq[w] > 1]
        print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq),
                                                                                       len(non_singletons)))

        top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size]

        for word in top_k_words:
            if len(vocab_entry) < size:
                if word_freq[word] >= freq_cutoff:
                    vocab_entry.add(word)

        return vocab_entry


class Vocab(object):
    def __init__(self, **kwargs):
        self.entries = []
        for key, item in kwargs.iteritems():
            assert isinstance(item, VocabEntry)
            self.__setattr__(key, item)

            self.entries.append(key)

    def __repr__(self):
        return 'Vocab(%s)' % (', '.join('%s %swords' % (entry, getattr(self, entry)) for entry in self.entries))


if __name__ == '__main__':
    raise NotImplementedError
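
Editor's note: a small sketch (not part of the commit) of building a vocabulary from tokenized sentences:

# hypothetical corpus
corpus = [['the', 'flight'], ['the', 'airline'], ['the', 'flight']]
src_vocab = VocabEntry.from_corpus(corpus, size=100, freq_cutoff=1)
print(src_vocab['flight'])   # a real id (>= 4)
print(src_vocab['unseen'])   # unk_id == 3 for out-of-vocabulary words
vocab = Vocab(source=src_vocab, primitive=src_vocab)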

exp.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
# coding=utf-8
from __future__ import print_function

import argparse
import cPickle as pickle
import numpy as np
import time
import math

import sys
import torch

from asdl.asdl import ASDLGrammar
from asdl.lang.py.py_transition_system import PythonTransitionSystem
from components.dataset import Dataset

from model.parser import Parser


def init_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=5783287, type=int, help='random seed')
    parser.add_argument('--cuda', action='store_true', default=False, help='use gpu')
    parser.add_argument('--mode', choices=['train', 'train_semi', 'test', 'debug_ls'], default='train', help='run mode')

    parser.add_argument('--batch_size', default=10, type=int, help='batch size')
    parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search')
    parser.add_argument('--sample_size', default=5, type=int, help='sample size')
    parser.add_argument('--embed_size', default=128, type=int, help='size of word embeddings')
    parser.add_argument('--action_embed_size', default=128, type=int, help='size of action embeddings')
    parser.add_argument('--field_embed_size', default=64, type=int, help='size of field embeddings')
    parser.add_argument('--type_embed_size', default=64, type=int, help='size of type embeddings')
    parser.add_argument('--ptrnet_hidden_dim', default=32, type=int)
    parser.add_argument('--hidden_size', default=256, type=int, help='size of LSTM hidden states')
    parser.add_argument('--dropout', default=0., type=float, help='dropout rate')
    parser.add_argument('--decoder_word_dropout', default=0.3, type=float, help='word dropout on decoder')
    parser.add_argument('--kl_anneal', default=False, action='store_true')
    parser.add_argument('--alpha', default=0.1, type=float)

    parser.add_argument('--asdl_file', type=str)
    parser.add_argument('--vocab', type=str, help='path of the serialized vocabulary')
    parser.add_argument('--train_src', type=str, help='path to the training source file')
    parser.add_argument('--unlabeled_src', type=str, help='path to the unlabeled source file')
    parser.add_argument('--unlabeled_tgt', type=str, default=None, help='path to the unlabeled target file')
    parser.add_argument('--train_file', type=str, help='path to the training set file')
    parser.add_argument('--dev_file', type=str, help='path to the dev set file')
    parser.add_argument('--test_file', type=str, help='path to the test set file')
    parser.add_argument('--prior_lm_path', type=str, help='path to the prior LM')

    # semi-supervised learning arguments
    parser.add_argument('--begin_semisup_after_dev_acc', type=float, default=0.,
                        help='begin semi-supervised learning after we have reached certain dev performance')

    parser.add_argument('--decode_max_time_step', default=80, type=int,
                        help='maximum number of time steps used in decoding and sampling')
    parser.add_argument('--unsup_loss_weight', default=1., type=float, help='weight of the unsupervised learning loss')

    parser.add_argument('--valid_metric', default='sp_acc', choices=['nlg_bleu', 'sp_acc'],
                        help='metric used for validation')
    parser.add_argument('--log_every', default=10, type=int, help='every n iterations to log training statistics')
    parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model')
    parser.add_argument('--save_to', default='model', type=str, help='save trained model to')
    parser.add_argument('--save_decode_to', default=None, type=str, help='save decoding results to file')
    parser.add_argument('--patience', default=5, type=int, help='training patience')
    parser.add_argument('--max_num_trial', default=10, type=int)
    parser.add_argument('--uniform_init', default=None, type=float,
                        help='if specified, use uniform initialization for all parameters')
    parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients')
    parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epochs')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--lr_decay', default=0.5, type=float,
                        help='decay learning rate if the validation performance drops')
    parser.add_argument('--lr_decay_after_epoch', default=5, type=int)
    parser.add_argument('--reset_optimizer', action='store_true', default=False)

    parser.add_argument('--train_opt', default="reinforce", type=str, choices=['reinforce', 'st_gumbel'])

    args = parser.parse_args()

    # seed the RNG
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed * 13 / 7)

    return args


if __name__ == '__main__':
    args = init_config()

    grammar = ASDLGrammar.from_text(open(args.asdl_file).read())
    transition_system = PythonTransitionSystem(grammar)
    train_set = Dataset.from_bin_file(args.train_file)
    vocab = pickle.load(open(args.vocab))

    parser = Parser(args, vocab, transition_system)
    parser.train()
    if args.cuda: parser.cuda()
    optimizer = torch.optim.Adam(parser.parameters(), lr=args.lr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= 100]
            train_iter += 1
            optimizer.zero_grad()

            loss = -parser.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(parser.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples), file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
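
The file ends as validation begins. A hedged sketch of the epoch-level validation step that the flags above imply (--patience, --lr_decay, --save_to) follows; evaluate_fn is a stand-in for whatever scorer computes args.valid_metric, and the whole function is an assumption, not code from this commit. It assumes Parser is an nn.Module so state_dict() is available.

# a hedged sketch of the validation / early-stopping step the flags imply;
# `evaluate_fn` is a hypothetical scorer, not part of this commit
import torch

def validate(parser, dev_set, optimizer, args, dev_scores, patience, evaluate_fn):
    dev_score = evaluate_fn(parser, dev_set)
    is_better = not dev_scores or dev_score > max(dev_scores)
    dev_scores.append(dev_score)

    if is_better:
        patience = 0
        torch.save(parser.state_dict(), args.save_to + '.bin')
    else:
        patience += 1
        if patience >= args.patience:
            # decay the learning rate and restart the patience counter
            patience = 0
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.lr_decay

    return patience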

‎model/__init__.py

Whitespace-only changes.

‎model/nn_utils.py

+94
@@ -0,0 +1,94 @@
# coding=utf-8

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None):
    """
    :param h_t: (batch_size, hidden_size)
    :param src_encoding: (batch_size, src_sent_len, hidden_size * 2)
    :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size)
    :param mask: (batch_size, src_sent_len), 1 marks a padding position
    """
    # (batch_size, src_sent_len)
    att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)
    if mask is not None:
        att_weight.data.masked_fill_(mask, -float('inf'))
    att_weight = F.softmax(att_weight, dim=-1)

    att_view = (att_weight.size(0), 1, att_weight.size(1))
    # (batch_size, hidden_size * 2)
    ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1)

    return ctx_vec, att_weight


def length_array_to_mask_tensor(length_array, cuda=False):
    # assumes length_array is sorted in descending order,
    # so the first entry is the maximum length
    max_len = length_array[0]
    batch_size = len(length_array)

    # 1 marks a padding position, 0 a real token
    mask = np.ones((batch_size, max_len), dtype=np.uint8)
    for i, seq_len in enumerate(length_array):
        mask[i][:seq_len] = 0

    mask = torch.ByteTensor(mask)
    return mask.cuda() if cuda else mask


def input_transpose(sents, pad_token):
    """
    transform the input List[sequence] of size (batch_size, max_sent_len)
    into a list of size (max_sent_len, batch_size), with proper padding
    """
    max_len = max(len(s) for s in sents)
    batch_size = len(sents)

    sents_t = []
    masks = []
    for i in xrange(max_len):
        sents_t.append([sents[k][i] if len(sents[k]) > i else pad_token for k in xrange(batch_size)])
        masks.append([1 if len(sents[k]) > i else 0 for k in xrange(batch_size)])

    return sents_t, masks


def word2id(sents, vocab):
    if type(sents[0]) == list:
        return [[vocab[w] for w in s] for s in sents]
    else:
        return [vocab[w] for w in sents]


def id2word(sents, vocab):
    if type(sents[0]) == list:
        return [[vocab.id2word[w] for w in s] for s in sents]
    else:
        return [vocab.id2word[w] for w in sents]


def to_input_variable(sequences, vocab, cuda=False, training=True, append_boundary_sym=False):
    """
    given a list of sequences,
    return a tensor of shape (max_sent_len, batch_size)
    """
    if append_boundary_sym:
        sequences = [['<s>'] + seq + ['</s>'] for seq in sequences]

    word_ids = word2id(sequences, vocab)
    sents_t, masks = input_transpose(word_ids, vocab['<pad>'])

    sents_var = Variable(torch.LongTensor(sents_t), volatile=(not training), requires_grad=False)
    if cuda:
        sents_var = sents_var.cuda()

    return sents_var


def variable_constr(x, v, cuda=False):
    # `x` names a tensor constructor (e.g. 'FloatTensor'); look it up with
    # getattr rather than the literal attribute `x`, which does not exist
    return Variable(getattr(torch.cuda, x)(v)) if cuda else Variable(getattr(torch, x)(v))
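
A quick shape check for dot_prod_attention, written against the 0.3-era PyTorch API this commit uses (Variable, ByteTensor masks); the tensor sizes are made up.

# minimal shape check, assuming nn_utils is importable as model.nn_utils
import torch
from torch.autograd import Variable
from model.nn_utils import dot_prod_attention, length_array_to_mask_tensor

batch_size, src_sent_len, hidden = 2, 5, 4
h_t = Variable(torch.randn(batch_size, hidden))
src_encoding = Variable(torch.randn(batch_size, src_sent_len, hidden * 2))
# src_encoding after a Linear(hidden * 2, hidden) projection; random here
src_encoding_att_linear = Variable(torch.randn(batch_size, src_sent_len, hidden))
# second sentence has 3 real tokens, the rest is padding
mask = length_array_to_mask_tensor([5, 3])

ctx_vec, att_weight = dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask)
print(ctx_vec.size())     # (2, 8)  -> (batch_size, hidden * 2)
print(att_weight.size())  # (2, 5)  -> padded positions get weight 0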

‎model/parser.py

+409
Large diffs are not rendered by default.

‎model/pointer_net.py

+38
@@ -0,0 +1,38 @@
# coding=utf-8

import torch
import torch.nn as nn
import torch.nn.utils
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class PointerNet(nn.Module):
    def __init__(self, args):
        super(PointerNet, self).__init__()

        self.src_encoding_linear = nn.Linear(args.hidden_size * 2, args.ptrnet_hidden_dim)
        self.query_vec_linear = nn.Linear(args.hidden_size, args.ptrnet_hidden_dim)
        self.layer2 = nn.Linear(args.ptrnet_hidden_dim, 1)

    def forward(self, src_encodings, src_token_mask, query_vec):
        """
        :param src_encodings: Variable(src_sent_len, batch_size, hidden_size * 2)
        :param src_token_mask: ByteTensor(batch_size, src_sent_len), 1 marks padding
        :param query_vec: Variable(tgt_action_num, batch_size, hidden_size)
        :return: Variable(tgt_action_num, batch_size, src_sent_len)
        """

        # (tgt_action_num, batch_size, src_sent_len, ptrnet_hidden_dim)
        h1 = torch.tanh(self.src_encoding_linear(src_encodings.permute(1, 0, 2)).unsqueeze(0)
                        + self.query_vec_linear(query_vec).unsqueeze(2))
        # (tgt_action_num, batch_size, src_sent_len)
        h2 = self.layer2(h1).squeeze(3)
        if src_token_mask is not None:
            # (tgt_action_num, batch_size, src_sent_len)
            src_token_mask = src_token_mask.unsqueeze(0).expand_as(h2)
            h2.data.masked_fill_(src_token_mask, -float('inf'))

        ptr_weights = F.softmax(h2, dim=-1)

        return ptr_weights
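
A minimal shape check for PointerNet under the same 0.3-era API; the Namespace stands in for the parsed command-line args and the sizes are invented.

from argparse import Namespace

import torch
from torch.autograd import Variable

from model.pointer_net import PointerNet

args = Namespace(hidden_size=256, ptrnet_hidden_dim=32)
ptr_net = PointerNet(args)

src_sent_len, batch_size, tgt_action_num = 7, 2, 3
src_encodings = Variable(torch.randn(src_sent_len, batch_size, args.hidden_size * 2))
query_vec = Variable(torch.randn(tgt_action_num, batch_size, args.hidden_size))

# no padding mask here; a (batch_size, src_sent_len) ByteTensor from
# nn_utils.length_array_to_mask_tensor would be passed in its place
ptr_weights = ptr_net(src_encodings, None, query_vec)
print(ptr_weights.size())  # (3, 2, 7): one distribution over source tokens
                           # per target action and batch element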

‎model/seq2seq.py

+434
Large diffs are not rendered by default.
