Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let model precision for XPU device align with CUDA #2587

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchbenchmark/models/nanogpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
)
# Create AdamW optimizer and use the fused version if it is available
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == "cuda"
use_fused = fused_available and device_type in ["cuda", "xpu"]
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(
optim_groups, lr=learning_rate, betas=betas, **extra_args
Expand Down
12 changes: 6 additions & 6 deletions torchbenchmark/util/extra_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def check_precision(
if precision == "bypass":
return True
if precision == "fp16":
return model.device == "cuda" and hasattr(model, "enable_fp16")
return model.device in ["cuda", "xpu"] and hasattr(model, "enable_fp16")
if precision == "tf32":
return model.device == "cuda"
return model.device in ["cuda", "xpu"]
if precision == "amp":
return True
if precision == "fx_int8":
return model.device == "cpu" and hasattr(model, "enable_fx_int8")
if precision == "bf16":
return True
if precision == "amp_fp16":
if model.test == "eval" and model.device == "cuda":
if model.test == "eval" and model.device in ["cuda", "xpu"]:
return True
if model.test == "train" and model.device == "cuda":
if model.test == "train" and model.device in ["cuda", "xpu"]:
return hasattr(model, "enable_amp") or is_staged_train_test(model)
if precision == "amp_bf16":
if model.test == "eval" and model.device == "cpu":
Expand Down Expand Up @@ -87,13 +87,13 @@ def get_precision_default(model: "torchbenchmark.util.model.BenchmarkModel") ->
if (
hasattr(model, "DEFAULT_EVAL_CUDA_PRECISION")
and model.test == "eval"
and model.device == "cuda"
and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_EVAL_CUDA_PRECISION
if (
hasattr(model, "DEFAULT_TRAIN_CUDA_PRECISION")
and model.test == "train"
and model.device == "cuda"
and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_TRAIN_CUDA_PRECISION
if hasattr(model, "DEFAULT_PRECISION"):
Expand Down