diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 119c304c8..d7b945bf6 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -190,7 +190,7 @@ def __init__( state: Dict[Any, Any] = {} param_group: Dict[str, Any] = { "params": [], - "lr": emb_module.get_learning_rate(), + "lr": emb_module.optimizer_args.learning_rate, } params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} @@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( state: Dict[Any, Any] = {} param_group: Dict[str, Any] = { "params": [], - "lr": emb_module.get_learning_rate(), + "lr": emb_module.optimizer_args.learning_rate, } params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} diff --git a/torchrec/modules/fused_embedding_modules.py b/torchrec/modules/fused_embedding_modules.py index 064fdea36..7a22cbf69 100644 --- a/torchrec/modules/fused_embedding_modules.py +++ b/torchrec/modules/fused_embedding_modules.py @@ -68,7 +68,7 @@ def __init__( # noqa C901 state: Dict[Any, Any] = {} param_group: Dict[str, Any] = { "params": [], - "lr": emb_module.get_learning_rate(), + "lr": emb_module.optimizer_args.learning_rate, } params: Dict[str, torch.Tensor] = {}