CUDA Graph Support in PyTorch: Improving Performance through Reduced CPU Overhead
Hello everyone, my name is Michael Carilli and today I'll be discussing CUDA graph support in PyTorch. This new feature improves performance by greatly reducing CPU overhead. To start, let's cover the basic concepts of CUDA graphs.
A CUDA graph is a static record of GPU work, which includes kernels, kernel arguments, and kernel execution dependencies. The entire record of work can be replayed with a single CUDA API call. Graphs help because they greatly reduce CPU overhead for static workloads. During ordinary eager execution, each operation incurs CPU overhead from Python, the PyTorch backend, and the CUDA driver. Per op this overhead is moderate, but it adds up.
As GPU throughput increases relative to CPUs, we often see eager execution bound by CPU overhead. In the timeline on the left, you can see the GPU running each kernel faster than the CPU can feed it, so utilization is poor. However, in many network regions, shapes and operations don't change across iterations, which means most of that CPU overhead is redundant. We don't need to rerun all the CPU work that computes shapes and launches individual kernels every time.
We can record a CUDA graph of a single dummy iteration, then replay the graph in the training loop, which removes all that CPU work from the critical path. This also reduces jitter in distributed training. Jitter is a common cause of poor scaling efficiency, because collective operations like AllReduce need to wait for the slowest rank, and CPU overhead is a common source of jitter. CPU work is more sensitive than GPU work to sources of performance nondeterminism, such as OS scheduling, lock contention, and competition with data loading and augmentation.
The left picture shows two GPU timelines from different ranks for a CPU-bound network. The timelines are not well aligned. Rank one runs ahead: its AllReduce begins first and must spin waiting for rank zero's AllReduce. By removing CPU overhead from the critical path, graphs reduce jitter. The right picture shows GPU timelines of two ranks using graph replay. You can see that they're much better aligned.
Here is a detailed example showing the end of a backward pass and an optimizer step. Each pair of timelines shows the GPU work on two different ranks in the same job. In the graphed case, not only are the timelines denser, but relative jitter is reduced. Note that these timelines are not to scale; the graphed version was much faster overall.
Some speedups achieved with graphs on real networks: 1.1X on BERT at our max scale for the MLPerf submission, 1.6X on Mask R-CNN, again at our MLPerf max scale, and 2.2X on DLRM, as presented in our Deep Learning Examples.
However, there are some limitations to using graphs. The graphed workload must be static, and it must be able to run without mid-stream CPU involvement. This typically means no dynamic shapes, no dynamic control flow, no syncs, and no essential CPU logic, because the graph will ignore it during replay.
So, how can you use graphs in PyTorch? There are three user-facing pieces: CUDAGraph, a primitive graph wrapper; graph, a straightforward context manager that captures the work run in its context; and make_graphed_callables, a gadget that turns a callable (a module or function) into a graphed version.
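To make these pieces concrete, here is a minimal sketch of capture and replay using CUDAGraph and the graph context manager, based on the documented API; the tensor shape and the arithmetic are just placeholders.

```python
import torch

# Placeholder input whose memory address is baked into the captured graph.
static_input = torch.zeros(5, device="cuda")

# Warm up on a side stream so one-time lazy initialization happens
# outside of capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Capture: the work inside the context is recorded, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2

# Replay: fill the placeholder with new data, then launch the whole
# recorded region with a single call.
static_input.copy_(torch.full((5,), 3.0, device="cuda"))
g.replay()
print(static_output)  # a tensor of 6s, computed on the new input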
If your whole network is capturable, you should use the graph context manager to capture and replay an entire iteration. Here's how it works. First, run a few warm-up iterations. Then capture one full iteration (forward, backward, and optimizer step) on static placeholder inputs with the graph context manager. In the training loop, copy each new batch into those placeholders and replay the graph, as in the sketch below.
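Here is a minimal sketch of that full-iteration workflow, following the documented torch.cuda.graph pattern; the model, sizes, and synthetic batches are illustrative stand-ins, not the talk's actual example.

```python
import torch

# Illustrative model and optimizer; any capture-safe network follows the same pattern.
model = torch.nn.Sequential(torch.nn.Linear(512, 512),
                            torch.nn.ReLU(),
                            torch.nn.Linear(512, 10)).cuda()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static placeholders: their memory is what the captured graph reads.
static_input = torch.randn(32, 512, device="cuda")
static_target = torch.randint(0, 10, (32,), device="cuda")

# 1. Warm up a few iterations on a side stream. (In practice, use real
#    batches here, since the optimizer actually steps during warm-up.)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(static_input), static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# 2. Capture one full iteration: forward, backward, and optimizer step.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_pred = model(static_input)
    static_loss = loss_fn(static_pred, static_target)
    static_loss.backward()
    optimizer.step()

# 3. Training loop: copy each new batch into the placeholders and replay.
#    Synthetic batches stand in for a real data loader here.
real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.randint_like(static_target, 0, 10) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()  # runs forward + backward + step as one recorded graph
    # static_loss now holds this iteration's loss; grads are refilled in
    # place, so no zero_grad() is needed between replays.
```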
However, there is another way to use graphs in PyTorch when you must. If only parts of your network are capturable, you apply make_graphed_callables to each part you want to capture. It does not need manual workload warm-up, and unlike the graph context manager, it does not need manual copying of new data into placeholder inputs. Graphed callables' backward passes also run as graphs, as in the sketch below.
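A sketch along these lines shows the basic pattern; the modules, shapes, and synthetic batches are purely illustrative, and the graphed block is assumed to be a capture-safe piece of a larger, otherwise eager network.

```python
import torch

# A capture-safe block we want to graph, plus an eager head around it.
block = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()).cuda()
head = torch.nn.Linear(512, 10).cuda()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(list(block.parameters()) + list(head.parameters()),
                            lr=0.1)

# Sample input used for capture; its requires_grad state must match the
# real inputs the graphed callable will see during training.
sample_input = torch.randn(32, 512, device="cuda")

# No manual warm-up and no placeholder copying: make_graphed_callables
# handles warm-up internally and returns a graphed version of the module.
block = torch.cuda.make_graphed_callables(block, (sample_input,))

# Synthetic batches stand in for a real data loader here.
real_inputs = [torch.randn(32, 512, device="cuda") for _ in range(10)]
real_targets = [torch.randint(0, 10, (32,), device="cuda") for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    optimizer.zero_grad(set_to_none=True)
    out = block(data)                 # block's forward ops run as a graph replay
    loss = loss_fn(head(out), target)
    loss.backward()                   # block's backward ops also run as a graph
    optimizer.step()
```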
make_graphed_callables has these convenience features, but its replays run less efficiently than a full-network replay. Best practice is to use the graph context manager for full-iteration capture when you can, and make_graphed_callables when you must. The graphs API works with native AMP and native DDP; see the documentation for a minor gotcha.
Graphs should also be composable with JIT: you should be able to JIT a model, warm it up, and then graph it, although only with the nvFuser backend, and generally I'd regard JIT interoperation as experimental.
In conclusion, CUDA graphs are an exciting new feature in PyTorch that can improve performance by greatly reducing CPU overhead. Today we covered the basic concepts of CUDA graphs and how they work, the limitations to keep in mind, and the three user-facing pieces: CUDAGraph, graph, and make_graphed_callables.
If you've heard enough to consider CUDA graphs, the documentation is much better than this talk's brief summary, so take a look at it, and feel free to file an issue if anything is unclear. Thank you for listening, and enjoy the rest of the conference!