diff --git a/torchbenchmark/models/nanogpt/model.py b/torchbenchmark/models/nanogpt/model.py
index 3f01aa6e1b..d8d87c3758 100644
--- a/torchbenchmark/models/nanogpt/model.py
+++ b/torchbenchmark/models/nanogpt/model.py
@@ -349,7 +349,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
)
- # Create AdamW optimizer and use the fused version if it is available
+ # Create AdamW optimizer; use the fused version if it is available and the device supports it
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
- use_fused = fused_available and device_type == "cuda"
+ use_fused = fused_available and device_type in ["cuda", "xpu"]
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(
optim_groups, lr=learning_rate, betas=betas, **extra_args
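
For context, the selection logic this hunk touches can be sketched standalone as below (a minimal sketch of the patched behavior; make_adamw is an invented name, and params/lr/betas are placeholder arguments):

import inspect

import torch


def make_adamw(params, lr, betas, device_type):
    # Older PyTorch releases do not expose the `fused` keyword, so probe
    # the installed torch.optim.AdamW signature before requesting it.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    # After this patch, fused AdamW is requested on XPU as well as CUDA.
    use_fused = fused_available and device_type in ["cuda", "xpu"]
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(params, lr=lr, betas=betas, **extra_args)
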
diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py
index 16c4840256..15c4bd62e6 100644
--- a/torchbenchmark/util/extra_args.py
+++ b/torchbenchmark/util/extra_args.py
@@ -37,9 +37,9 @@ def check_precision(
if precision == "bypass":
return True
if precision == "fp16":
- return model.device == "cuda" and hasattr(model, "enable_fp16")
+ return model.device in ["cuda", "xpu"] and hasattr(model, "enable_fp16")
if precision == "tf32":
- return model.device == "cuda"
+ return model.device in ["cuda", "xpu"]
if precision == "amp":
return True
if precision == "fx_int8":
@@ -47,9 +47,9 @@ def check_precision(
if precision == "bf16":
return True
if precision == "amp_fp16":
- if model.test == "eval" and model.device == "cuda":
+ if model.test == "eval" and model.device in ["cuda", "xpu"]:
return True
- if model.test == "train" and model.device == "cuda":
+ if model.test == "train" and model.device in ["cuda", "xpu"]:
return hasattr(model, "enable_amp") or is_staged_train_test(model)
if precision == "amp_bf16":
if model.test == "eval" and model.device == "cpu":
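
The amp_fp16 branches above now treat XPU exactly like CUDA: eval is always eligible, while train additionally requires an enable_amp hook or a staged train test. A compressed sketch of that rule (amp_fp16_ok is an invented helper; the boolean arguments stand in for the hasattr and is_staged_train_test checks used in check_precision):

def amp_fp16_ok(device: str, test: str, has_enable_amp: bool, is_staged_train: bool) -> bool:
    # Mirrors the updated check_precision branch: XPU follows the CUDA path.
    if device not in ["cuda", "xpu"]:
        return False
    if test == "eval":
        return True
    # Training with AMP also needs an enable_amp hook or a staged train test.
    return has_enable_amp or is_staged_train
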
@@ -87,13 +87,13 @@ def get_precision_default(model: "torchbenchmark.util.model.BenchmarkModel") ->
if (
hasattr(model, "DEFAULT_EVAL_CUDA_PRECISION")
and model.test == "eval"
- and model.device == "cuda"
+ and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_EVAL_CUDA_PRECISION
if (
hasattr(model, "DEFAULT_TRAIN_CUDA_PRECISION")
and model.test == "train"
- and model.device == "cuda"
+ and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_TRAIN_CUDA_PRECISION
if hasattr(model, "DEFAULT_PRECISION"):
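
Since the same membership test now appears throughout extra_args.py, the pattern reduces to a single predicate (a hypothetical refactor, not part of this patch; ACCELERATORS and supports_cuda_style_precision are invented names):

ACCELERATORS = ("cuda", "xpu")


def supports_cuda_style_precision(device: str) -> bool:
    # Devices that take the CUDA precision paths: fp16, tf32, amp_fp16,
    # and the DEFAULT_{EVAL,TRAIN}_CUDA_PRECISION defaults.
    return device in ACCELERATORS

Note that the attribute names DEFAULT_EVAL_CUDA_PRECISION and DEFAULT_TRAIN_CUDA_PRECISION keep their CUDA wording even though, after this patch, they also supply the defaults on XPU.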