Skip to content

Commit

Permalink
Step 2 of 2: Removal of 3D soft argmax calculations from METRABS Head…
Browse files Browse the repository at this point in the history
… for NCNN compatibility (isarandi#49)

* Removal of 3D soft argmax calculations from METRABS Head

Modification of METRABS Head to remove 3D soft argmax functionality. This is now computed in Loom SDK to support NCNN compatibility.

METRABS Trainer class had to be modified to do these calculations during training and testing time.

* Migrating duplicate functionality to separate function

Migrating shared functionality in MetrabsTrainer.forward_train and MetrabsTrainer.forward_test to _shared_process_inps

* fix error

---------

Co-authored-by: ylee <[email protected]>
  • Loading branch information
2 people authored and GitHub Enterprise committed Aug 29, 2023
1 parent 78006a9 commit 9751e10
Showing 1 changed file with 58 additions and 34 deletions.
92 changes: 58 additions & 34 deletions src/models/metrabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ def __init__(self, backbone, joint_info):
def call(self, inp, training=None):
image = inp
features = self.backbone(image, training=training)
coords2d, coords3d = self.heatmap_heads(features, training=training)

return coords2d, coords3d

'''
Updates to reflect new function signature of self.heatmap_heads which now only returns feature vectors.
Note that this may not be compatible with simcc heads.
'''
volume_3d = self.heatmap_heads(features, training=training)
return volume_3d

@tf.function
def predict_multi(self, image):
Expand All @@ -59,24 +64,10 @@ def __init__(self, n_points):
self.n_outs = [self.n_points, FLAGS.depth * self.n_points]
self.conv_final = keras.layers.Conv2D(filters=sum(self.n_outs), kernel_size=1)

## Modified implementation of METRABS head does not have 3D soft argmax and only returns feature vectors.
def call(self, inp, training=None):
x = self.conv_final(inp)
logits2d, logits3d = tf.split(x, self.n_outs, axis=tfu.channel_axis())

if FLAGS.data_format == 'NCHW':
current_format = 'b h w (d j)' if tfu.get_data_format() == 'NHWC' else 'b (j d) h w'
logits3d = einops.rearrange(logits3d, f'{current_format} -> b j d h w', j=self.n_points)
coords3d = tfu.soft_argmax(tf.cast(logits3d, tf.float32), axis=[4, 3, 2])
else:
current_format = 'b h w (d j)' if tfu.get_data_format() == 'NHWC' else 'b (d j) h w'
logits3d = einops.rearrange(logits3d, f'{current_format} -> b h w d j', j=self.n_points)
coords3d = tfu.soft_argmax(tf.cast(logits3d, tf.float32), axis=[2, 1, 3])


coords3d_rel_pred = models.util.heatmap_to_metric(coords3d, training)
coords2d = tfu.soft_argmax(tf.cast(logits2d, tf.float32), axis=tfu.image_axes()[::-1])
coords2d_pred = models.util.heatmap_to_image(coords2d, training)
return coords2d_pred, coords3d_rel_pred
return x


class MetrabsTrainer(models.model_trainer.ModelTrainer):
Expand All @@ -87,25 +78,62 @@ def __init__(self, metrabs_model, joint_info, joint_info2d=None, global_step=Non
self.joint_info_2d = joint_info2d
self.model = metrabs_model

## Estimating self.n_raw_points based on FLAGS.output_upper_joints
if FLAGS.output_upper_joints:
self.n_raw_points = 8
else:
self.n_raw_points = 32 if FLAGS.transform_coords else joint_info.n_joints

if FLAGS.data_format == 'NCHW':
inp = keras.Input(shape=(3, None, None), dtype=tfu.get_dtype())
else:
inp = keras.Input(shape=(None, None,3), dtype=tfu.get_dtype())

## Variable used for feature splitting into 2d and 3D [8, 8*8]
self.n_outs = [self.n_raw_points, FLAGS.depth * self.n_raw_points]

self.model(inp, training=False)

def forward_train(self, inps, training):
def _shared_process_inps(self, inps, training):
preds = AttrDict()

image_both = tf.concat([inps.image, inps.image_2d], axis=0)
features = self.model.backbone(image_both, training=training)
coords2d_pred_both, coords3d_rel_pred_both = self.model.heatmap_heads(
features, training=training)
batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]]
preds.coords2d_pred, preds.coords2d_pred_2d = tf.split(
coords2d_pred_both, batch_sizes, axis=0)
preds.coords3d_rel_pred, preds.coords3d_rel_pred_2d = tf.split(
coords3d_rel_pred_both, batch_sizes, axis=0)
if training:
image_both = tf.concat([inps.image, inps.image_2d], axis=0)
features = self.model.backbone(image_both, training=training)
else:
features = self.model.backbone(inps.image, training=training)

volume_3d = self.model.heatmap_heads(features, training=training)
logits2d, logits3d = tf.split(volume_3d, self.n_outs, axis=tfu.channel_axis())

if FLAGS.data_format == 'NCHW':
current_format = 'b h w (d j)' if tfu.get_data_format() == 'NHWC' else 'b (j d) h w'
logits3d = einops.rearrange(logits3d, f'{current_format} -> b j d h w', j=self.n_raw_points)
coords3d = tfu.soft_argmax(tf.cast(logits3d, tf.float32), axis=[4, 3, 2])
else:
current_format = 'b h w (d j)' if tfu.get_data_format() == 'NHWC' else 'b (d j) h w'
logits3d = einops.rearrange(logits3d, f'{current_format} -> b h w d j', j=self.n_raw_points)
coords3d = tfu.soft_argmax(tf.cast(logits3d, tf.float32), axis=[2, 1, 3])

preds.coords3d_rel_pred = models.util.heatmap_to_metric(coords3d, training)
coords2d = tfu.soft_argmax(tf.cast(logits2d, tf.float32), axis=tfu.image_axes()[::-1])
preds.coords2d_pred = models.util.heatmap_to_image(coords2d, training)

if training:
batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]]
preds.coords2d_pred, preds.coords2d_pred_2d = tf.split(
preds.coords2d_pred, batch_sizes, axis=0)
preds.coords3d_rel_pred, preds.coords3d_rel_pred_2d = tf.split(
preds.coords3d_rel_pred, batch_sizes, axis=0)

preds.coords3d_pred_abs = tfu3d.reconstruct_absolute(
preds.coords2d_pred, preds.coords3d_rel_pred, inps.intrinsics)

return preds

def forward_train(self, inps, training):

preds = self._shared_process_inps (inps, training)

if FLAGS.transform_coords:
l2j = self.model.latent_points_to_joints
Expand Down Expand Up @@ -207,12 +235,8 @@ def compute_metrics(self, inps, preds):
return models.eval_metrics.compute_pose3d_metrics_j8(inps, preds) if FLAGS.output_upper_joints else models.eval_metrics.compute_pose3d_metrics(inps, preds)

def forward_test(self, inps):
preds = AttrDict()
features = self.model.backbone(inps.image, training=False)
preds.coords2d_pred, preds.coords3d_rel_pred = self.model.heatmap_heads(
features, training=False)
preds.coords3d_pred_abs = tfu3d.reconstruct_absolute(
preds.coords2d_pred, preds.coords3d_rel_pred, inps.intrinsics)

preds = self._shared_process_inps (inps, training=False)

if FLAGS.transform_coords:
l2j = self.model.latent_points_to_joints
Expand Down

0 comments on commit 9751e10

Please sign in to comment.