Part 2 - Increase your training throughput with FSDP activation checkpointing

# Using Checkpoint Activations with Fully Sharded Data Parallel (FSDP) in PyTorch

## Introduction

In this article, we will explore how to use activation checkpointing with Fully Sharded Data Parallel (FSDP) in PyTorch. This technique is particularly useful for reducing GPU memory usage when training large models. The article follows the step-by-step video walkthrough by Les from Meta AI, restating that explanation here with an emphasis on clarity and readability.

---

## Setting Up the Environment

The first step is to ensure you have the correct version of PyTorch installed. As mentioned in the video, activation checkpointing for FSDP is supported starting with the June 18th nightly builds, so any PyTorch build from June 18th onward will work.
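As a quick sanity check, you can print the installed version; for nightly builds the dev-date suffix tells you whether the build is recent enough (the exact version string below is only illustrative):

```python
import torch

# A nightly from June 18th or later (or any subsequent stable release) is sufficient.
# Nightly version strings carry their build date, e.g. something like '1.13.0.dev20220618'.
print(torch.__version__)
```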

You will need to import a few key libraries:

```python
import torch
from functools import partial
```

Additionally, you’ll need the main FSDP import along with the activation-checkpointing utilities:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing_wrapper,  # exposed as apply_activation_checkpointing in newer releases
)
```

These are the standard imports for combining FSDP with activation checkpointing.

---

## Understanding Checkpoint Wrapping

The core of the implementation involves creating a **checkpoint wrapper**. This wrapper will be used to wrap layers in your model that you want to checkpoint during training. The process involves:

1. Bringing in `checkpoint_wrapper`, which is what will actually wrap each layer you want to checkpoint.

2. Using `CheckpointImpl`, an enum that currently has two options: `REENTRANT` and `NO_REENTRANT`. For the best performance, we’ll use `NO_REENTRANT`.

3. Using `apply_activation_checkpointing_wrapper` to walk through your model and apply the checkpoint wrapper to the appropriate layers.

The key trade-off here is between performance and memory usage. The `NO_REENTRANT` implementation provides the best performance while still allowing you to free up GPU memory during training; we also keep the checkpointed activations on the GPU rather than offloading them to the CPU.
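Putting this together, the wrapper is typically built with `functools.partial` so that `apply_activation_checkpointing_wrapper` can apply it to each selected layer later. The sketch below mirrors the configuration described in the video (non-reentrant checkpointing, no CPU offload); the variable name `non_reentrant_wrapper` is just illustrative, and note that newer PyTorch releases move CPU offloading into a separate `offload_wrapper`:

```python
# Build the checkpoint wrapper once; it will be applied to every layer
# that the check function (defined below) selects.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,                         # keep checkpointed activations on the GPU
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,  # the best-performing implementation
)
```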

---

## Applying Checkpoint Wrappers

To identify which layers in your model should be checkpointed, we define a **check function** as a simple lambda. As `apply_activation_checkpointing_wrapper` walks through your model, it calls this function on each submodule to decide whether to wrap it. For example:

```python
# `Residual` is the transformer block class of the example model (DeepViT in the video);
# import it from your model's definition and substitute your own layer class as needed.
check_fn = lambda submodule: isinstance(submodule, Residual)
```

This assumes you’re working with a transformer-based model whose `Residual` blocks are the layers you want to checkpoint. If you built the model yourself, you will generally know which layer class to target.

Once you’ve defined your check function, you can apply the checkpoint wrappers using:

```python
apply_activation_checkpointing_wrapper(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,  # the partial-wrapped checkpoint_wrapper built above
    check_fn=check_fn,
)
```

Here, `model` is your sharded model (already wrapped with FSDP), `checkpoint_wrapper_fn` is the checkpoint wrapper we built above, and `check_fn` is the lambda that identifies which layers to wrap. The call loops through the model and applies the checkpoint wrapping in place; afterwards you can print the model to confirm that both the FSDP sharding units and the checkpoint wrappers appear where you expect.

---

## Initializing FSDP and Sharding the Model

Before applying activation checkpointing, you must initialize FSDP and shard your model. This involves setting up parameters like:

- **Sharding strategy**: Decide how your model weights are distributed across GPUs.

- **Mixed-precision policy**: Set the desired precision for parameters, gradients, and buffers (e.g., bf16/fp16 versus fp32).

- **Auto-wrap policy**: Define how layers are grouped into FSDP units, typically keyed on the same transformer layer class used by the check function above.

Once your model is sharded, you pass it to `apply_activation_checkpointing_wrapper` along with your wrapper and check functions, as in the sketch below.
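Here is a minimal end-to-end sketch of that ordering, under illustrative assumptions: the distributed process group is already initialized, `Residual`, `check_fn`, and `non_reentrant_wrapper` are the objects defined above, and the specific FSDP settings (full sharding, bf16 mixed precision) are placeholders to tune for your own hardware and model:

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Group the model's transformer blocks into FSDP units, keyed on the same
# layer class the check function uses.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Residual},
)

# Example mixed-precision policy (bf16 everywhere); adjust to your hardware.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# 1) Shard the (un-wrapped) model with FSDP first...
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=bf16_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
)

# 2) ...then apply activation checkpointing to the sharded model.
apply_activation_checkpointing_wrapper(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=check_fn,
)

# Optional: print(model) to confirm both FSDP units and checkpoint wrappers are in place.
```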

---

## Performance Considerations

There’s one main trade-off when using activation checkpointing: a slowdown in training time. You should expect step times to increase by roughly 20-25%, because the activations of checkpointed layers are not stored during the forward pass and must be recomputed during the backward pass, effectively doubling the forward compute for those layers. However, this cost is more than offset by the memory savings.

In the presenter’s experience with T5 and similar models, activation checkpointing frees up roughly **33% to 38% of GPU memory**. That freed memory can be spent on a larger batch size, which in turn has yielded total training throughput improvements on the order of **2-3x or more**.
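If you want to verify the savings on your own model, one simple (illustrative) approach is to compare peak GPU memory for a single training step with and without the checkpoint wrapping applied. The `step_fn` callable below is a hypothetical stand-in for your own forward/backward/optimizer step:

```python
def peak_memory_gib(step_fn):
    """Run one training step and return peak GPU memory in GiB (rough illustration)."""
    torch.cuda.reset_peak_memory_stats()
    step_fn()                    # your forward + backward + optimizer step
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3

# Measure once before and once after calling apply_activation_checkpointing_wrapper(...)
# to estimate how much memory the checkpointing frees on your model.
```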

This makes activation checkpointing a valuable tool for improving both memory efficiency and overall training throughput with FSDP.

---

## Conclusion

Activation checkpointing with FSDP is a powerful technique for reducing memory usage when training large neural networks. By following the steps outlined in this article, you can implement this optimization effectively while balancing the trade-off between extra recomputation and the memory, and therefore throughput, you gain.

"WEBVTTKind: captionsLanguage: enhi everyone my name is les with meta-ai and today we're going to talk about uh using checkpoint activations with fstp or fully charted parallel so i've got a notebook set up here to show you how to do that and started with the first cell up here initial requirement of course is um a version of pytorch that supports activation checkpointing specifically for fstp and that actually happened currently with the june 18th nightlys so anything june 18th and beyond will be sufficient so you go ahead and just bring that in you see we're using june 18th uh we'll need two basic imports torch itself of course uh but we'll also need func tools uh partial because we're gonna do a little bit of function currying and we've got our main fsdp imports so standard items there and now we'll get into actually sort of the core code that we need here for checkpoint wrapping there are three items we're bringing in the checkpoint wrapper itself which is what will literally wrap the layers that we're going to checkpoint with our checkpoint we check my imple which is an enum there's currently two options there and we're just going to use the most performant one and then finally the function for applying the checkpoint wrappers so that will run through your model and checkpoint at the appropriate layers for you so that will bring us to the next question which is what layers should you be checkpointing obviously if you're a model builder directly you can probably have a good intuition of how to do that but we do have an automatic function and it's very similar to the transformer wrapper function so if you haven't seen this tutorial this will be helpful for you to identify how to pick out which layer class you want to use but generically we want to use the layer classes of a given transformer so in this case uh we'll just pretend we're using the deep bit from the pie torch and so in that case the layer class is the residual so we're bringing that in here and then from there we need to make a lambda function and this is a check function that's going to basically as we walk through the model programmatically we want to identify if this is the right layer class that we want to checkpoint wrap so it's very straightforward thing here so we've got residual we brought in from the transformer class or transformer model uh and then we're gonna just identify that so this is your check function there so that's the first thing we need excuse me after that this is fairly boilerplate code at this point but it's basically saying checkpoint wrapper uh we're not gonna offload a cpu uh and we are gonna use a non-re-entrant form uh this is back to that enum that i mentioned uh the non-re-entrant is the best performance so that's what we're gonna stick to uh oops and then from there uh the key last step before you actually check my wrap is to make sure that you have actually already initialized fsdp and basically sharded your model um so this would be an example here where we've got our model set up obviously the wrapping policy image precision policy and the other things that will then present our sharded model take the sharded model and that's what we want to pass to the apply activation checkpoint wrapper function so bring in our model our checkpoint wrapper that we built up above and our check function to identify what layers we want to do or to check my wrap and that's actually it once you apply this it will actually loop through apply the checkpoint wrapping to your model and you are ready to go you 
can print the model and you'll see within that the appropriate breakout between the sharding as well as the checkpoint we're happening best practices excuse me um there's really two trade-offs or one main trade-off you will should expect to see a roughly 20 to 25 slow down in your training time just because you're effectively doing a double compute both in the forward pass and then redundantly doing compete in the backward pass but in exchange you are freeing up uh from what i've seen with t5 and some other models anywhere from 33 to 38 gpu memory so you can combine that by leveraging the freedom memory by increasing your batch size and of course that increased batch size greatly increases your total training throughput and see increases on the order of two 3x plus in terms of total throughput and therefore total training to our improvements and training time so it's a very significant tool in your tool chest for improving your training time and experience with fstp so hope that helpshi everyone my name is les with meta-ai and today we're going to talk about uh using checkpoint activations with fstp or fully charted parallel so i've got a notebook set up here to show you how to do that and started with the first cell up here initial requirement of course is um a version of pytorch that supports activation checkpointing specifically for fstp and that actually happened currently with the june 18th nightlys so anything june 18th and beyond will be sufficient so you go ahead and just bring that in you see we're using june 18th uh we'll need two basic imports torch itself of course uh but we'll also need func tools uh partial because we're gonna do a little bit of function currying and we've got our main fsdp imports so standard items there and now we'll get into actually sort of the core code that we need here for checkpoint wrapping there are three items we're bringing in the checkpoint wrapper itself which is what will literally wrap the layers that we're going to checkpoint with our checkpoint we check my imple which is an enum there's currently two options there and we're just going to use the most performant one and then finally the function for applying the checkpoint wrappers so that will run through your model and checkpoint at the appropriate layers for you so that will bring us to the next question which is what layers should you be checkpointing obviously if you're a model builder directly you can probably have a good intuition of how to do that but we do have an automatic function and it's very similar to the transformer wrapper function so if you haven't seen this tutorial this will be helpful for you to identify how to pick out which layer class you want to use but generically we want to use the layer classes of a given transformer so in this case uh we'll just pretend we're using the deep bit from the pie torch and so in that case the layer class is the residual so we're bringing that in here and then from there we need to make a lambda function and this is a check function that's going to basically as we walk through the model programmatically we want to identify if this is the right layer class that we want to checkpoint wrap so it's very straightforward thing here so we've got residual we brought in from the transformer class or transformer model uh and then we're gonna just identify that so this is your check function there so that's the first thing we need excuse me after that this is fairly boilerplate code at this point but it's basically saying checkpoint wrapper uh 
we're not gonna offload a cpu uh and we are gonna use a non-re-entrant form uh this is back to that enum that i mentioned uh the non-re-entrant is the best performance so that's what we're gonna stick to uh oops and then from there uh the key last step before you actually check my wrap is to make sure that you have actually already initialized fsdp and basically sharded your model um so this would be an example here where we've got our model set up obviously the wrapping policy image precision policy and the other things that will then present our sharded model take the sharded model and that's what we want to pass to the apply activation checkpoint wrapper function so bring in our model our checkpoint wrapper that we built up above and our check function to identify what layers we want to do or to check my wrap and that's actually it once you apply this it will actually loop through apply the checkpoint wrapping to your model and you are ready to go you can print the model and you'll see within that the appropriate breakout between the sharding as well as the checkpoint we're happening best practices excuse me um there's really two trade-offs or one main trade-off you will should expect to see a roughly 20 to 25 slow down in your training time just because you're effectively doing a double compute both in the forward pass and then redundantly doing compete in the backward pass but in exchange you are freeing up uh from what i've seen with t5 and some other models anywhere from 33 to 38 gpu memory so you can combine that by leveraging the freedom memory by increasing your batch size and of course that increased batch size greatly increases your total training throughput and see increases on the order of two 3x plus in terms of total throughput and therefore total training to our improvements and training time so it's a very significant tool in your tool chest for improving your training time and experience with fstp so hope that helps\n"