
Commit c698440

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Frontend) - reland
Summary: Re-land of D68055168 with fixes.

------

The previous landing of D68055168 caused S498612 with the error

```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```

Assertion failure point: https://fburl.com/code/q1xg0w8m

**Root cause:** This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only in the combination of the following:
- only in JIT script
- only for module attributes

**Explanation**:
1. The module is scripted, e.g., in [`module_factory/model_materializer_full_sync.py`](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623) and [`module_factory/sync_sgd/train_module.py`](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615).
2. JIT refines the class type `__torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen`, adding its module attributes to `refined_slots`.
3. For each attribute, it asserts that the `refined_slots` entry is a JIT subtype of the attribute's type ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409)).
4. The `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We therefore changed the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as a float) to `OptimizerArgsPT2` (with `learning_rate_tensor`), as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` assumes the type is `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same. {F1976090817} See the [log of the refined_slots debug](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726).
6. Since the JIT type system sees them as different types, the assertion is triggered.

**Note**: The JIT type system fails to reconcile the types only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail. {F1976090829} See the [log of the refined_slots debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738).
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.

-------

__**Solution**__

We discussed several workarounds; the best solution is to
- keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and
- create a `learning_rate_tensor` in the invokers before passing it to `lookup_function` (see the sketch below).
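A rough Python sketch of this workaround is below. It is illustrative only, not the actual FBGEMM code: `invoke_lookup` and `lookup_function` are hypothetical stand-ins for the generated invoker and lookup op, and `OptimizerArgs` is reduced to two fields.

```
from typing import NamedTuple

import torch


class OptimizerArgs(NamedTuple):
    # learning_rate stays a plain float, exactly as before the re-land,
    # so existing callers of optimizer_args.learning_rate keep working.
    learning_rate: float
    eps: float


def lookup_function(weights: torch.Tensor, learning_rate_tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the generated lookup op, which takes a tensor LR.
    return weights * learning_rate_tensor


def invoke_lookup(weights: torch.Tensor, optimizer_args: OptimizerArgs) -> torch.Tensor:
    # The invoker materializes the tensor form of the learning rate on CPU
    # right before the call; optimizer_args itself never holds a Tensor,
    # so the JIT-script attribute issue described above is avoided.
    learning_rate_tensor = torch.tensor(
        optimizer_args.learning_rate, dtype=torch.float32, device="cpu"
    )
    return lookup_function(weights, learning_rate_tensor)


if __name__ == "__main__":
    args = OptimizerArgs(learning_rate=0.01, eps=1e-8)
    print(invoke_lookup(torch.ones(4), args))
```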
Moreover, since we have realized that many users access `learning_rate` directly through `optimizer_args.learning_rate`, changing `optimizer_args` could break backward compatibility for those cases; keeping `OptimizerArgs` unchanged is therefore also the best way to maintain backward compatibility. Note that `learning_rate_tensor` is always created on the CPU, so there should be no host-device synchronization (see the sketch below).

Differential Revision: D71010630
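As a small illustration (not part of this commit) of why the CPU placement matters: reading a scalar out of a tensor with `.item()` is a plain host memory access when the tensor lives on the CPU, but forces a device-to-host copy and synchronization when it lives on a GPU. This is also the reason for the `TORCH_CHECK(learning_rate_tensor.is_cpu(), ...)` guard added in `embedding_backward_split_template.cu` below.

```
import torch

# Learning rate held as a CPU scalar tensor, as the invokers now create it:
lr_cpu = torch.tensor(0.01)       # lives in host memory
lr_value = lr_cpu.item()          # plain host read, no device synchronization

# For contrast: if the tensor lived on a GPU, the same .item() call would
# copy the value back to the host and synchronize with the device.
if torch.cuda.is_available():
    lr_gpu = torch.tensor(0.01, device="cuda")
    lr_value_gpu = lr_gpu.item()  # triggers a device-to-host sync

print(lr_value)
```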
1 parent b7a4e51 commit c698440

7 files changed: +337 −351 lines changed

fbgemm_gpu/codegen/genscript/generate_backward_split.py

+1 −1

```
@@ -447,7 +447,7 @@ def generate() -> None:
         ssd_optimizers.append(optim)
 
         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, **optimizer
+            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
         )
         BackwardSplitGenerator.generate_rocm_backward_split()
```

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

+1
```
@@ -601,6 +601,7 @@ Tensor {{ embedding_cuda_op }}(
 
     {%- if "learning_rate" in args.split_kernel_arg_names %}
     // convert `learning rate` to float since `learning rate` is float in kernels
+    TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
     const float learning_rate = learning_rate_tensor.item<float>();
     {%- endif %}
```

fbgemm_gpu/codegen/training/python/lookup_args.template

+2 −31

```
@@ -50,7 +50,8 @@ class CommonArgs(NamedTuple):
     ssd_tensors: Dict[str, torch.Tensor]
     {%- endif %}
 
-
+# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
+# See more detail in D71010630.
 class OptimizerArgs(NamedTuple):
     stochastic_rounding: bool
     gradient_clipping: bool
@@ -108,36 +109,6 @@ class CommonArgsPT2(NamedTuple):
     ssd_tensors: Dict[str, torch.Tensor]
     {%- endif %}
 
-class OptimizerArgsPT2(NamedTuple):
-    """
-    Optimizer arguments for PT2 interface
-    """
-    stochastic_rounding: bool
-    gradient_clipping: bool
-    max_gradient: float
-    max_norm: float
-    learning_rate_tensor: torch.Tensor
-    eps: float
-    beta1: float
-    beta2: float
-    weight_decay: float
-    weight_decay_mode: int
-    eta: float
-    momentum: float
-    counter_halflife: int
-    adjustment_iter: int
-    adjustment_ub: float
-    learning_rate_mode: int
-    grad_sum_decay: int
-    tail_id_threshold: float
-    is_tail_id_thresh_ratio: int
-    total_hash_size: int # Required for OptimType.NONE
-    weight_norm_coefficient: float
-    lower_bound: float
-    regularization_mode: int
-    use_rowwise_bias_correction: bool # Used for OptimType.ADAM
-
-
 class Momentum(NamedTuple):
     dev: torch.Tensor
     host: torch.Tensor
```

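For context on the "Do not add a parameter of Type tensor here" comment above, the following is a minimal, hypothetical sketch of the attribute shape the commit message describes as problematic: a scripted module whose attribute is a NamedTuple containing a `torch.Tensor` field. Whether it actually trips the `refined_slots` internal assert depends on the PyTorch build and on how the module is scripted, so the scripting call is wrapped in a try/except.

```
from typing import NamedTuple

import torch


class ArgsWithTensor(NamedTuple):
    # A Tensor field inside a NamedTuple-typed module attribute is the
    # pattern the comment above warns against.
    learning_rate_tensor: torch.Tensor
    eps: float


class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.optimizer_args = ArgsWithTensor(
            learning_rate_tensor=torch.tensor(0.01), eps=1e-8
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.optimizer_args.learning_rate_tensor


try:
    scripted = torch.jit.script(ToyModule())
    print(scripted(torch.ones(2)))
except Exception as e:
    # On affected PyTorch builds, attribute refinement can fail with the
    # refined_slots INTERNAL ASSERT described in the commit message.
    print(f"torch.jit.script failed: {e}")
```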