diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 27448b87a922..d4fc294e11d7 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -732,21 +732,20 @@ def show_profile(precision, profile_name):
 
     if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
         print("This example requires CUDA with fp8 support.")
-        exit(1)
-
-    dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16
+    else:
+        dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16
 
-    if args.K and args.K_range is None:
-        args.K_range = [args.K, args.K]
-        args.K_step = 1  # doesn't matter as long as it's not 0
+        if args.K and args.K_range is None:
+            args.K_range = [args.K, args.K]
+            args.K_step = 1  # doesn't matter as long as it's not 0
 
-    torch.manual_seed(0)
+        torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, args.K_range[0], dtype)
+        validate(32, 32, 32, dtype)
+        validate(8192, 8192, args.K_range[0], dtype)
 
-    proton.start("matmul", hook="triton")
-    for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
-    proton.finalize()
-    show_profile(args.prec, "matmul")
+        proton.start("matmul", hook="triton")
+        for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
+            bench(K, dtype)
+        proton.finalize()
+        show_profile(args.prec, "matmul")
diff --git a/python/tutorials/10-block-scaled-matmul.py b/python/tutorials/10-block-scaled-matmul.py
index e664fe39db2a..a97d1118b50c 100644
--- a/python/tutorials/10-block-scaled-matmul.py
+++ b/python/tutorials/10-block-scaled-matmul.py
@@ -324,17 +324,16 @@ def show_profile(profile_name):
     parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8"], default="nvfp4")
     args = parser.parse_args()
 
-    torch.manual_seed(42)
-
     if not supports_block_scaling():
         print("⛔ This example requires GPU support for block scaled matmul")
-        exit(1)
+    else:
+        torch.manual_seed(42)
 
-    validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)
+        validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)
 
-    if args.bench:
-        proton.start("block_scaled_matmul", hook="triton")
-        for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-            bench_block_scaled(K, reps=10000, block_scale_type=args.format)
-        proton.finalize()
-        show_profile("block_scaled_matmul")
+        if args.bench:
+            proton.start("block_scaled_matmul", hook="triton")
+            for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
+                bench_block_scaled(K, reps=10000, block_scale_type=args.format)
+            proton.finalize()
+            show_profile("block_scaled_matmul")