[DOC] Fix generating docs for tutorials (#5850)
Do not `exit` when generating tutorials, which causes the whole
documentation pipeline to exit.
Jokeren authored Feb 7, 2025
1 parent 4554cea commit 80d9a94
Showing 2 changed files with 22 additions and 24 deletions.
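For context, a minimal sketch of the failure mode, assuming the docs generator (e.g. sphinx-gallery) executes each tutorial in the same Python process: calling exit() raises SystemExit, which propagates out of the tutorial script and aborts the entire build. The helper names below (run_tutorial_body, tutorial_before, tutorial_after) are hypothetical, not code from this commit.

def run_tutorial_body() -> None:
    """Stand-in for the validation/benchmarking work a tutorial does."""
    print("running benchmarks...")

def tutorial_before(supported: bool) -> None:
    # Old pattern: exit() raises SystemExit, which propagates out of the
    # tutorial and aborts the entire documentation pipeline.
    if not supported:
        print("This example requires CUDA with fp8 support.")
        exit(1)
    run_tutorial_body()

def tutorial_after(supported: bool) -> None:
    # New pattern from this commit: print the message and skip the body,
    # returning control to the docs generator cleanly.
    if not supported:
        print("This example requires CUDA with fp8 support.")
    else:
        run_tutorial_body()

if __name__ == "__main__":
    tutorial_after(supported=False)  # prints the message, returns normally
    try:
        tutorial_before(supported=False)
    except SystemExit as e:
        # This is the error the docs pipeline hit before the fix.
        print(f"docs build would have stopped with exit code {e.code}")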
python/tutorials/09-persistent-matmul.py (27 changes: 13 additions & 14 deletions)

@@ -732,21 +732,20 @@ def show_profile(precision, profile_name):
 
     if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
         print("This example requires CUDA with fp8 support.")
-        exit(1)
-
-    dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16
+    else:
+        dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16
 
-    if args.K and args.K_range is None:
-        args.K_range = [args.K, args.K]
-        args.K_step = 1  # doesn't matter as long as it's not 0
+        if args.K and args.K_range is None:
+            args.K_range = [args.K, args.K]
+            args.K_step = 1  # doesn't matter as long as it's not 0
 
-    torch.manual_seed(0)
+        torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, args.K_range[0], dtype)
+        validate(32, 32, 32, dtype)
+        validate(8192, 8192, args.K_range[0], dtype)
 
-    proton.start("matmul", hook="triton")
-    for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
-    proton.finalize()
-    show_profile(args.prec, "matmul")
+        proton.start("matmul", hook="triton")
+        for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
+            bench(K, dtype)
+        proton.finalize()
+        show_profile(args.prec, "matmul")
python/tutorials/10-block-scaled-matmul.py (19 changes: 9 additions & 10 deletions)

@@ -324,17 +324,16 @@ def show_profile(profile_name):
     parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8"], default="nvfp4")
     args = parser.parse_args()
 
-    torch.manual_seed(42)
-
     if not supports_block_scaling():
         print("⛔ This example requires GPU support for block scaled matmul")
-        exit(1)
+    else:
+        torch.manual_seed(42)
 
-    validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)
+        validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)
 
-    if args.bench:
-        proton.start("block_scaled_matmul", hook="triton")
-        for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-            bench_block_scaled(K, reps=10000, block_scale_type=args.format)
-        proton.finalize()
-        show_profile("block_scaled_matmul")
+        if args.bench:
+            proton.start("block_scaled_matmul", hook="triton")
+            for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
+                bench_block_scaled(K, reps=10000, block_scale_type=args.format)
+            proton.finalize()
+            show_profile("block_scaled_matmul")
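The alternative would have been to harden the pipeline side instead of the tutorials. A hedged sketch of that option, purely for illustration: this is not what the commit does, and build_tutorial_docs is a hypothetical function, not sphinx-gallery's API.

import runpy

def build_tutorial_docs(tutorial_paths):
    # Hypothetical in-process generator loop, for illustration only.
    for path in tutorial_paths:
        try:
            # Run the tutorial as if it were the main script; exit(1)
            # inside it surfaces here as SystemExit.
            runpy.run_path(path, run_name="__main__")
        except SystemExit as e:
            # Without this guard, one tutorial calling exit() kills the
            # whole build. After this commit, the tutorials no longer
            # raise SystemExit, so the guard is unnecessary.
            print(f"{path}: skipped (exit code {e.code})")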
