Skip to content

Commit 6beee1a

Browse files
committed
Revert "fat stack: tear out scheduled sampling and replace with full sampling"
This reverts commit 350a059.
1 parent 0981f34 commit 6beee1a

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

python/spinn/fat_stack.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ def _make_inputs(self):
257257
self.X = self.X or T.imatrix("X")
258258
self.transitions = self.transitions or T.imatrix("transitions")
259259

260-
def _step(self, transitions_t, stack_t, buffer_cur_t, tracking_hidden,
261-
attention_hidden, buffer, ground_truth_transitions_visible,
262-
premise_stack_tops, projected_stack_tops):
260+
def _step(self, transitions_t, ss_mask_gen_matrix_t, stack_t, buffer_cur_t,
261+
tracking_hidden, attention_hidden, buffer,
262+
ground_truth_transitions_visible, premise_stack_tops, projected_stack_tops):
263263
"""TODO document"""
264264
batch_size, _ = self.X.shape
265265

@@ -299,8 +299,25 @@ def _step(self, transitions_t, stack_t, buffer_cur_t, tracking_hidden,
299299
logits_use_cell=self._predict_use_cell,
300300
name="prediction_and_tracking")
301301

302-
# HACK: Sample from action multinomial
303-
mask = ss_mask_gen.multinomial(pvals=actions_t).nonzero()[1]
302+
if self.train_with_predicted_transitions:
303+
# Model 2 case.
304+
if self.interpolate:
305+
# Only use ground truth transitions if they are marked as visible to the model.
306+
effective_ss_mask_gen_matrix_t = ss_mask_gen_matrix_t * ground_truth_transitions_visible
307+
# Interpolate between truth and prediction using bernoulli RVs
308+
# generated prior to the step.
309+
mask = (transitions_t * effective_ss_mask_gen_matrix_t
310+
+ actions_t.argmax(axis=1) * (1 - effective_ss_mask_gen_matrix_t))
311+
else:
312+
# Use predicted actions to build a mask.
313+
mask = actions_t.argmax(axis=1)
314+
elif self._predict_transitions:
315+
# Use transitions provided from external parser when not masked out
316+
mask = (transitions_t * ground_truth_transitions_visible
317+
+ actions_t.argmax(axis=1) * (1 - ground_truth_transitions_visible))
318+
else:
319+
# Model 0 case.
320+
mask = transitions_t
304321

305322
# Now update the stack: first precompute reduce results.
306323
if self.model_dim != self.stack_dim:
@@ -425,6 +442,18 @@ def _make_scan(self):
425442

426443
# Prepare data to scan over.
427444
sequences = [transitions]
445+
if self.interpolate:
446+
# Generate Bernoulli RVs to simulate scheduled sampling
447+
# if the interpolate flag is on.
448+
ss_mask_gen_matrix = self.ss_mask_gen.binomial(
449+
transitions.shape, p=self.ss_prob)
450+
# Take in the RV sequence as input.
451+
sequences.append(ss_mask_gen_matrix)
452+
else:
453+
# Take in the RV sequence as a dummy output. This is
454+
# done to avoid defining another step function.
455+
outputs_info = [DUMMY] + outputs_info
456+
428457
non_sequences = [buffer_t, self.ground_truth_transitions_visible]
429458

430459
if self.use_attention != "None" and self.is_hypothesis:

0 commit comments

Comments
 (0)