
Developer guide

Jacob Hinkle edited this page Nov 7, 2024 · 23 revisions


Build nvfuser

# Optionally, fetch a remote branch
$ git fetch origin <BRANCH_NAME>
$ git switch <BRANCH_NAME>

# Only needed initially or when submodules are updated. 
$ git submodule update --init --recursive

$ pip install -e .

See https://github.com/NVIDIA/Fuser/wiki/Building-fuser-project for a more detailed guide.

Test nvfuser

$ ./manual_ci.sh

Benchmark nvfuser

$ bin/nvfuser_bench [--benchmark_filter=<FILTER_REGEX>]

To run only the nvFuser-based benchmarks:

$ bin/nvfuser_bench --benchmark_filter=NvFuserScheduler
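nvfuser_bench appears to be built on Google Benchmark, since it takes --benchmark_filter and reports manual_time. As far as we know, Google Benchmark treats the filter as a regular expression that may match anywhere in a benchmark name, so the name does not have to match in full. The sketch below mimics that selection with Python's re module, using names from the example output on this page. select_benchmarks is a hypothetical helper, and Python's regex dialect may differ from Google Benchmark's for anything beyond simple patterns.

```python
import re

# Benchmark names taken from the example compare_benchmark.py output on this page.
names = [
    "NvFuserScheduler_LayerNorm_fp16___GRAPH/NvFuserScheduler_LayerNorm_fp16/16/2097152/manual_time",
    "NvFuserScheduler_LayerNorm_LargeHiddenSize_fp32___GRAPH/NvFuserScheduler_LayerNorm_LargeHiddenSize_fp32/8192/34816/manual_time",
    "NvFuserScheduler_Reduction_Outer_fp32___GRAPH/NvFuserScheduler_Reduction_Outer_fp32/1024/8/manual_time",
]


def select_benchmarks(names, pattern):
    """Keep the names the regex matches anywhere (search semantics, not a full match)."""
    regex = re.compile(pattern)
    return [name for name in names if regex.search(name)]


print(select_benchmarks(names, "NvFuserScheduler"))  # all three
print(select_benchmarks(names, "LayerNorm"))         # the first two
print(select_benchmarks(names, "LayerNorm_fp16/"))   # only the first
```

A short substring such as LayerNorm is usually enough; anchor or lengthen the pattern when it selects more benchmarks than intended.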

Often, you'd like to measure the performance impact of a change.

$ python tools/compare_benchmark.py <baseline_branch_or_commit> <contender_branch_or_commit> <out_dir> -- <args to bin/nvfuser_bench, e.g., --benchmark_filter=NvFuserScheduler>

This script builds and runs both the baseline and the contender, and compares the two results. It also skips the expensive benchmarking when <out_dir> already contains <baseline_branch_or_commit>.json or <contender_branch_or_commit>.json. Below is an example output.

Top 5 improvements:
  Benchmark NvFuserScheduler_LayerNorm_fp16___GRAPH/NvFuserScheduler_LayerNorm_fp16/16/2097152/manual_time changed from 235.70649740466797 seconds 173.04190461790853 (0.73x)
  Benchmark NvFuserScheduler_LayerNorm_fp16___GRAPH/NvFuserScheduler_LayerNorm_fp16/8/2097152/manual_time changed from 138.39395947220675 seconds 102.51789488799382 (0.74x)
  Benchmark NvFuserScheduler_Broadcast_Outer_fp16___GRAPH/NvFuserScheduler_Broadcast_Outer_fp16/8/2097152/manual_time changed from 66.96629544563652 seconds 50.55804438076262 (0.75x)
  Benchmark NvFuserScheduler_Broadcast_Outer_fp16___GRAPH/NvFuserScheduler_Broadcast_Outer_fp16/16/2097152/manual_time changed from 115.90680760401582 seconds 91.29062974831477 (0.79x)
  Benchmark NvFuserScheduler_LayerNorm_LargeHiddenSize_fp32___GRAPH/NvFuserScheduler_LayerNorm_LargeHiddenSize_fp32/8192/34816/manual_time changed from 1788.1167894834052 seconds 1567.6858208396222 (0.88x)

Top 5 regressions:
  Benchmark NvFuserScheduler_Reduction_Outer_fp32___GRAPH/NvFuserScheduler_Reduction_Outer_fp32/1024/8/manual_time changed from 9.868966872637673 seconds 10.735423786849232 (1.09x)
  Benchmark NvFuserScheduler_Broadcast_Outer_fp16___GRAPH/NvFuserScheduler_Broadcast_Outer_fp16/1/320/manual_time changed from 5.658855892095212 seconds 6.134677312527026 (1.08x)
  Benchmark NvFuserScheduler_Reduction_Inner_fp16___GRAPH/NvFuserScheduler_Reduction_Inner_fp16/8/320/manual_time changed from 6.313002350175524 seconds 6.8305892441798015 (1.08x)
  Benchmark NvFuserScheduler_Reduction_Outer_fp16___GRAPH/NvFuserScheduler_Reduction_Outer_fp16/2/4096/manual_time changed from 6.264241068695716 seconds 6.77447848660655 (1.08x)
  Benchmark NvFuserScheduler_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_BatchNorm_nhwc_fp16/2/8/2/manual_time changed from 7.216345496261045 seconds 7.765639299028032 (1.08x)

Saved the histogram of time changes to /home/me/workspace/benchmark_test/histogram.png. 

$ open /home/me/workspace/benchmark_test/histogram.png

[Figure: histogram of time changes (histogram.png)]
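The ratios in the example output are the contender's time divided by the baseline's (e.g., 173.04 / 235.71 ≈ 0.73), so values below 1.0x are speedups. The following is a minimal sketch of that comparison, not the actual code of tools/compare_benchmark.py. It assumes each <branch_or_commit>.json is a Google Benchmark JSON report whose "benchmarks" list has entries with "name" and "real_time" fields; times_by_name, load_times and compare are hypothetical helpers.

```python
import json


def times_by_name(report):
    """Map benchmark name -> real_time for a parsed Google Benchmark JSON report."""
    return {b["name"]: b["real_time"] for b in report["benchmarks"]}


def load_times(path):
    """Read <branch_or_commit>.json from disk (assumed to be Google Benchmark JSON)."""
    with open(path) as f:
        return times_by_name(json.load(f))


def compare(baseline, contender, top=5):
    """Return (improvements, regressions) as lists of (name, old, new, ratio).

    ratio = contender / baseline, so ratio < 1.0 means the contender is faster.
    Benchmarks missing from either side are skipped.
    """
    rows = [
        (name, old, contender[name], contender[name] / old)
        for name, old in baseline.items()
        if name in contender
    ]
    rows.sort(key=lambda row: row[3])  # smallest ratio (biggest speedup) first
    improvements = [row for row in rows if row[3] < 1.0][:top]
    regressions = [row for row in reversed(rows) if row[3] > 1.0][:top]
    return improvements, regressions


# Two timings copied from the example output above (names shortened).
baseline = {"LayerNorm_fp16/16": 235.70649740466797, "Reduction_Outer_fp32/1024/8": 9.868966872637673}
contender = {"LayerNorm_fp16/16": 173.04190461790853, "Reduction_Outer_fp32/1024/8": 10.735423786849232}
improvements, regressions = compare(baseline, contender)
for name, old, new, ratio in improvements + regressions:
    print(f"{name}: {old:.2f} -> {new:.2f} ({ratio:.2f}x)")
```

With real reports you would call compare(load_times(...), load_times(...)) on the two JSON files in <out_dir>.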

Debug workflow for nvFuser

Debugging a failing nvFuser Python script typically proceeds as follows.

  1. A compilation error is observed when running a Python script. The error message includes a reproducer Python script that defines the fusion and some inputs.
  2. You begin debugging by inspecting where the error came from and isolating the Fusion segment that failed to compile.
  3. You isolate a repro for that failing segment and simplify it as much as possible while checking that it still triggers the bad behavior.
  4. (optional) You copy the repro and the error into a new issue on the nvFuser repo (https://github.com/NVIDIA/Fuser/issues/new) and describe what you were doing.
  5. You use NVFUSER_DUMP options and gdb to inspect the runtime state of nvFuser, determine the root cause, and find a fix.

In step 1, the repro will look something like this:

An error occurred while executing nvFuser FusionDefinition 0.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
```python
# CUDA devices:
#  0: NVIDIA H100 80GB HBM3
# torch version: 2.6.0a0+gitffb7a08
# cuda version: 12.6
# nvfuser version: 0.2.22+git6912435
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 32768, 2], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    # ...
    fd.add_output(T273)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
]
fd.execute(inputs)
```

The accompanying compile error might look like the following:

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 182, in execute
    results = self._execute(
              ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":368, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 11:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1995, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T125_l___bfloat[ iblockIdx.x846{( ceilDiv(2, blockDim.x) )}, ithreadIdx.x847{blockDim.x}, iS855{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(32768, blockDim.y) ), 16) ), 1) ), gridDim.y) )}, iblockIdx.y854{gridDim.y}, ithreadIdx.y849{blockDim.y}, iUS853{1}, iUR851{16}, bS505{1} ] ca_pos( 6 ) dim: 2 id: iS507{2}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:1995 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x91 (0x7ff45f092448 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
...

This indicates that segmentation group 11 is the one with the problem.
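Because segments are compiled in parallel, the message may report errors from several segmentation groups. Below is a small sketch for pulling the group numbers out of captured error text, assuming the "Error from segmentation group N:" wording shown above stays stable. failing_groups is a hypothetical helper; in a script you could pass it str(e) from an except RuntimeError as e: block around fd.execute(inputs).

```python
import re

# Matches lines such as "Error from segmentation group 11:  INTERNAL ASSERT FAILED ..."
_GROUP_RE = re.compile(r"Error from segmentation group (\d+):")


def failing_groups(error_text):
    """Return the failing segmentation group numbers, in order of appearance."""
    return [int(n) for n in _GROUP_RE.findall(error_text)]


excerpt = (
    "Detected exception while compiling fusion segments in parallel. "
    "Error messages from all threads are printed below.\n\n"
    "Error from segmentation group 11:  INTERNAL ASSERT FAILED at ...\n"
)
print(failing_groups(excerpt))  # [11]
```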

Step 2 is aided by launching your script with NVFUSER_DUMP=python_definition_segments, e.g., NVFUSER_DUMP=python_definition_segments python foo.py. For each segment, this prints a fusion definition that is smaller than the overall repro shown above:

Python definition for segmented group 8:

def nvfuser_fusion_id8(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 32768, 56], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False)
    T1 = fd.ops.squeeze(T0, dims=[0], squeeze_expanded=True)
    T2 = fd.ops.permute(T1, dims=[1, 0])
    fd.add_output(T1, stride_order=[1, 0])
    fd.add_output(T2, stride_order=[0, 1])

Find the group whose number matches the failing one in the error message, then cut and paste its definition into a new, more targeted repro. Don't forget to change the inputs to match those expected by the segment's fusion definition.
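The repro at the top of this section builds each strided input by allocating a flat tensor and calling as_strided. The flat size must cover the largest offset the strides reach, i.e. 1 + Σ(size_i − 1)·stride_i elements. For shape (1, 28, 32768, 2) and strides (7340032, 262144, 8, 1) that is 1 + 0·7340032 + 27·262144 + 32767·8 + 1·1 = 7340026, matching the repro. The sketch below computes this when you adapt inputs to a segment; required_storage is a hypothetical helper and assumes non-negative strides.

```python
def required_storage(shape, strides):
    """Smallest flat storage (in elements) that as_strided(shape, strides) can index into.

    Assumes non-negative strides, as in the repro above.
    """
    if any(size == 0 for size in shape):
        return 0  # an empty tensor touches no elements
    # The last element sits at offset sum((size - 1) * stride); add 1 to get a count.
    return 1 + sum((size - 1) * stride for size, stride in zip(shape, strides))


shape, strides = (1, 28, 32768, 2), (7340032, 262144, 8, 1)
print(required_storage(shape, strides))  # 7340026

# With torch available, the input would then be built as in the repro:
# torch.randn(required_storage(shape, strides), dtype=torch.bfloat16,
#             device="cuda:0").as_strided(shape, strides)
```

For a contiguous segment input such as the [1, 32768, 56] tensor above, the result is simply the element count, so a plain torch.randn(1, 32768, 56, ...) is enough.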

Reference

NVFUSER_DUMP

Use the NVFUSER_DUMP environment variable to control which intermediate results are dumped and to enable verbose logging. It can be prepended to any command that launches nvFuser, e.g., bin/nvfuser_tests, bin/nvfuser_bench, and python3 a_python_script_that_imports_and_runs_nvfuser.py. csrc/options.cpp lists all dump options and their meanings.

Examples:

  • NVFUSER_DUMP=cuda_kernel prints the generated CUDA kernels.
  • NVFUSER_DUMP=segmenter_logging prints which scheduler gets used.
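Prepending NVFUSER_DUMP=<options> to a command simply sets that variable in the launched process's environment. When you drive nvFuser runs from a Python harness, a sketch like the following does the same. run_with_dump is a hypothetical helper, and joining several options with commas is an assumption to verify against csrc/options.cpp.

```python
import os
import subprocess
import sys


def run_with_dump(cmd, dump_options):
    """Run cmd with NVFUSER_DUMP set, like `NVFUSER_DUMP=... cmd` in a shell.

    Assumption: several options are passed as one comma-separated value.
    """
    env = dict(os.environ)  # inherit the current environment (PATH, CUDA variables, ...)
    env["NVFUSER_DUMP"] = ",".join(dump_options)
    return subprocess.run(cmd, env=env, capture_output=True, text=True)


# Equivalent to: NVFUSER_DUMP=python_definition_segments python foo.py
# result = run_with_dump([sys.executable, "foo.py"], ["python_definition_segments"])
# print(result.stdout)
```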

gdb

$ python setup.py develop --debug
$ gdb --args bin/nvfuser_tests --gtest_filter=<FILTER>
(gdb) catch throw nvfuser::nvfError
(gdb) r

asan

$ python setup.py develop --build-with-asan
# The ASAN_OPTIONS is needed to work around https://github.com/google/sanitizers/issues/629.
$ ASAN_OPTIONS=protect_shadow_gap=0 <CMD>

Profile heap

# Install tcmalloc and some other tools. 
$ sudo apt install google-perftools

# For me, tcmalloc was installed at /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
$ LD_PRELOAD=<path to libtcmalloc.so> HEAPPROFILE=/tmp/<NAME> <CMD>

The above command should print "Starting tracking the heap" at the beginning. During or at the end of the program execution, you should see messages like "Dumping heap profile to /tmp/..heap". These dumped heap profiles can then be examined with pprof:

$ sudo apt install golang
$ go install github.com/google/pprof@latest
$ $HOME/go/bin/pprof -dot -output /tmp/<NAME>.dot /tmp/<NAME>.<NUMBER>.heap
$ dot -Tpng /tmp/<NAME>.dot -o /tmp/<NAME>.png
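If the program dumps several profiles, you usually want the last one. To our knowledge, gperftools names the dumps <prefix>.<NUMBER>.heap with a zero-padded, increasing NUMBER (child processes may add their PID to the prefix). The sketch below picks the highest-numbered dump for a given HEAPPROFILE prefix, comparing numbers as integers so the order stays correct even if the counter outgrows its padding. pick_latest and latest_heap_profile are hypothetical helpers.

```python
import glob
import os
import re


def pick_latest(paths, prefix):
    """Return the path named <prefix>.<NUMBER>.heap with the largest NUMBER (file name only)."""
    pattern = re.compile(re.escape(os.path.basename(prefix)) + r"\.(\d+)\.heap")
    best_number, best_path = -1, None
    for path in paths:
        m = pattern.fullmatch(os.path.basename(path))
        if m and int(m.group(1)) > best_number:
            best_number, best_path = int(m.group(1)), path
    return best_path


def latest_heap_profile(prefix):
    """Scan the filesystem for dumps written with HEAPPROFILE=<prefix>; None if there are none."""
    return pick_latest(glob.glob(glob.escape(prefix) + ".*.heap"), prefix)


# e.g. latest_heap_profile("/tmp/<NAME>") gives the .heap file to pass to pprof above.
```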

Profile kernels

You can profile kernels with nsys (Nsight Systems) or ncu (Nsight Compute). For example,

$ nsys profile <CMD>
$ nsys stats --report cuda_gpu_kern_sum <the .nsys-rep file generated by the above command>
$ ncu -k <KERNEL_NAME_FILTER> <CMD>

Unlike nsys, ncu by default tries to stabilize measurements by flushing GPU caches and locking clocks. Run ncu -h to see the knobs that change this behavior (e.g., --cache-control none and --clock-control none).

For a better UI, you can have ncu export profiling results to a .ncu-rep file on the remote machine and open it in the Nsight Compute GUI on your host (e.g., a MacBook). Note that Nsight Compute is a different tool from Nsight Systems.

$ ncu -o <OUTPUT_NAME> <OTHER_OPTIONS> <CMD>
...
==PROF== Report: <OUTPUT_NAME>.ncu-rep

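When examining an NVRTC-compiled kernel, it's useful to associate the CUDA source with the lowered device code. This requires compiling with -lineinfo and having the CUDA source available. In the command below, NVFUSER_ENABLE=kernel_lineinfo enables the former and NVFUSER_DUMP=cuda_to_file writes the generated CUDA source to files.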

$ NVFUSER_ENABLE=kernel_lineinfo NVFUSER_DUMP=cuda_to_file ncu <CMD>