@@ -257,9 +257,9 @@ def _make_inputs(self):
257
257
self .X = self .X or T .imatrix ("X" )
258
258
self .transitions = self .transitions or T .imatrix ("transitions" )
259
259
260
- def _step (self , transitions_t , stack_t , buffer_cur_t , tracking_hidden ,
261
- attention_hidden , buffer , ground_truth_transitions_visible ,
262
- premise_stack_tops , projected_stack_tops ):
260
+ def _step (self , transitions_t , ss_mask_gen_matrix_t , stack_t , buffer_cur_t ,
261
+ tracking_hidden , attention_hidden , buffer ,
262
+ ground_truth_transitions_visible , premise_stack_tops , projected_stack_tops ):
263
263
"""TODO document"""
264
264
batch_size , _ = self .X .shape
265
265
@@ -299,8 +299,25 @@ def _step(self, transitions_t, stack_t, buffer_cur_t, tracking_hidden,
299
299
logits_use_cell = self ._predict_use_cell ,
300
300
name = "prediction_and_tracking" )
301
301
302
- # HACK: Sample from action multinomial
303
- mask = ss_mask_gen .multinomial (pvals = actions_t ).nonzero ()[1 ]
302
+ if self .train_with_predicted_transitions :
303
+ # Model 2 case.
304
+ if self .interpolate :
305
+ # Only use ground truth transitions if they are marked as visible to the model.
306
+ effective_ss_mask_gen_matrix_t = ss_mask_gen_matrix_t * ground_truth_transitions_visible
307
+ # Interpolate between truth and prediction using Bernoulli RVs
308
+ # generated prior to the step.
309
+ mask = (transitions_t * effective_ss_mask_gen_matrix_t
310
+ + actions_t .argmax (axis = 1 ) * (1 - effective_ss_mask_gen_matrix_t ))
311
+ else :
312
+ # Use predicted actions to build a mask.
313
+ mask = actions_t .argmax (axis = 1 )
314
+ elif self ._predict_transitions :
315
+ # Use transitions provided from external parser when not masked out
316
+ mask = (transitions_t * ground_truth_transitions_visible
317
+ + actions_t .argmax (axis = 1 ) * (1 - ground_truth_transitions_visible ))
318
+ else :
319
+ # Model 0 case.
320
+ mask = transitions_t
304
321
305
322
# Now update the stack: first precompute reduce results.
306
323
if self .model_dim != self .stack_dim :
@@ -425,6 +442,18 @@ def _make_scan(self):
425
442
426
443
# Prepare data to scan over.
427
444
sequences = [transitions ]
445
+ if self .interpolate :
446
+ # Generate Bernoulli RVs to simulate scheduled sampling
447
+ # if the interpolate flag is on.
448
+ ss_mask_gen_matrix = self .ss_mask_gen .binomial (
449
+ transitions .shape , p = self .ss_prob )
450
+ # Take in the RV sequence as input.
451
+ sequences .append (ss_mask_gen_matrix )
452
+ else :
453
+ # Take in the RV sequence as a dummy output. This is
454
+ # done to avoid defining another step function.
455
+ outputs_info = [DUMMY ] + outputs_info
456
+
428
457
non_sequences = [buffer_t , self .ground_truth_transitions_visible ]
429
458
430
459
if self .use_attention != "None" and self .is_hypothesis :
0 commit comments