PYTORCH DISTRIBUTED - YANLI ZHAO

Introduction to PyTorch Distributed Package

The PyTorch distributed package is a set of features that facilitate training models on distributed systems. Distributed training is usually categorized into data parallelism and model parallelism. PyTorch distributed data parallel trains a replica of the model on each device independently on different batches of data and then synchronizes the results at the end of each training iteration; this synchronization is built on top of collective calls. Model parallelism splits a model across multiple devices and creates a pipeline of computation to train the model across those devices. It needs to communicate intermediate results between machines, and this communication is usually built on top of RPC or point-to-point calls.
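
To make the two communication styles concrete, the sketch below shows one collective call and one point-to-point exchange using torch.distributed directly. It is only illustrative: it assumes a process group has already been initialized (for example with the gloo backend) and that there are at least two ranks.

    import torch
    import torch.distributed as dist

    def demo(rank: int):
        # Collective call: every data-parallel process participates and ends up
        # with the same result (here, the element-wise sum across all ranks).
        tensor = torch.ones(4) * rank
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

        # Point-to-point call: one process ships an intermediate result to a
        # specific peer, the style that model-parallel pipelines build on.
        if rank == 0:
            dist.send(torch.randn(4), dst=1)
        elif rank == 1:
            buf = torch.empty(4)
            dist.recv(buf, src=0)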

Newly Added Data Parallel APIs in PyTorch Distributed

The first data parallel API newly added to PyTorch this year is the Zero Redundancy Optimizer, also called ZeRO. The original idea comes from DeepSpeed ZeRO. In PyTorch, even in distributed data parallel (DDP) training, each process needs to hold the parameters, gradients, and local optimizer states of the replicated model. This consumes a lot of memory and limits the model size that can be trained on a single device. The local state of some commonly used optimizers can be very large; for example, an optimizer such as Adam keeps local state that consumes roughly twice as much memory as the model itself. Based on this observation, ZeRO shards the optimizer states among the data-parallel processes to reduce the memory footprint. The higher memory efficiency of ZeRO allows larger models to be trained on a single device.
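
As a rough illustration of that memory cost, the snippet below uses a toy, hypothetical model to check that Adam keeps about two extra values per parameter (its exp_avg and exp_avg_sq buffers); the exact ratio depends on the optimizer and precision used.

    import torch

    model = torch.nn.Linear(1024, 1024)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Run one step so Adam lazily creates its state buffers.
    model(torch.randn(8, 1024)).sum().backward()
    opt.step()

    param_numel = sum(p.numel() for p in model.parameters())
    state_numel = sum(
        t.numel()
        for s in opt.state.values()
        for t in s.values()
        if torch.is_tensor(t)
    )
    print(state_numel / param_numel)  # approximately 2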

Training Loop with ZeRO Optimizer

Combining DDP with the ZeRO optimizer looks much like a normal DDP training loop: it first defines a model and wraps it with the DDP API, then defines a ZeRO optimizer that wraps a regular optimizer class, and after that runs the usual forward pass, backward pass, and optimizer step. In the optimizer step, ZeRO updates only its shard of the parameters and local states and then synchronizes the updated parameters with the peer processes, so that all processes hold the same model state at the beginning of the next iteration. As this shows, ZeRO helps train larger models on a single device, but the scale is still limited because each process still needs to host the full set of model parameters and gradients.
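
A minimal sketch of such a loop follows. The model, data loader, and hyperparameters are placeholders, and it assumes the process group has already been initialized (for example via torchrun) with rank as the local CUDA device.

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed.optim import ZeroRedundancyOptimizer

    def train(rank, model, data_loader):
        # Wrap the model with DDP as usual.
        model = DDP(model.to(rank), device_ids=[rank])
        # Wrap a regular optimizer class; each rank keeps only its shard of the
        # optimizer state instead of the full, replicated state.
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=1e-3,
        )
        loss_fn = torch.nn.CrossEntropyLoss()
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(rank), targets.to(rank)
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            # Each rank updates its parameter shard, then the updated parameters
            # are synchronized so every rank starts the next iteration identically.
            optimizer.step()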

Fully Sharded Data Parallel API (FSDP) in PyTorch Distributed

The FSDP API was originally built by the Facebook AI team. We are upstreaming the API and planning to launch it, with some improvements, in PyTorch 1.11. FSDP shards parameters, gradients, and optimizer states across all data-parallel processes. Layers are usually wrapped with FSDP in a nested way, and the wrapped layers are sharded evenly across devices, so once all layers are sharded each process holds one shard of the whole model. Before a layer's forward computation, an all-gather is called to collect the full weights for that layer; the gathered weights are freed as soon as the forward computation is done so the memory can be reused for the next layer, which keeps peak memory low. Because CUDA operations are asynchronous, the next layer's all-gather can also overlap with the previous layer's computation, giving good training performance. Similarly, in the backward pass, weights are gathered before each layer's computation and freed afterwards, and gradients are synchronized after each layer's backward computation. Once the whole backward pass is done, the optimizer updates the states for its local shard. Experiments showed that up to one trillion dense transformer parameters can be trained on 256 GPUs with FSDP. Since FSDP can automatically wrap layers in a nested way, users can simply swap the DDP API for the FSDP API in a normal DDP training loop to train large-scale models.
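
A minimal sketch of that swap is shown below, assuming the FSDP import path planned for PyTorch 1.11; the model, data, and optimizer settings are placeholders.

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def train(rank, model, data_loader):
        # Where a DDP loop would call DDP(model.to(rank)), wrap with FSDP instead
        # so parameters, gradients, and optimizer states are sharded across ranks.
        model = FSDP(model.to(rank))
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(rank), targets.to(rank)
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()   # gradients are synchronized per wrapped layer
            optimizer.step()  # each rank updates only its local shard of state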

Sharded Tensor: A Generic Abstraction for Parallel Algorithm Implementation

In the long term, we envision that sharded tensor could become a generic abstraction that helps users implement parallel algorithms easily. A sharded tensor is a single-process, multiple-device style implementation of a tensor that is sharded across multiple devices. Users only need to annotate a tensor with a ShardingSpec, and the sharded tensor takes care of all the distributed computation. This makes intra-layer parallelism for a linear layer very simple to implement: the layer is defined as usual, and a sharding spec chunks its weight across four devices.

Sharded Tensor Example

In the example, a linear layer is defined and a sharding spec that chunks the weight across four devices is annotated on the weight parameter. The code then runs as a normal linear layer; underneath, sharded tensors are built and a gather operation is called before the forward computation. Because of the sharded tensor APIs, users do not need to change their model construction code to achieve intra-layer parallelism; only a minimal change in the training process is needed.
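
The sketch below illustrates this example using the prototype sharding APIs (ChunkShardingSpec and shard_parameter under torch.distributed._shard). The layer size, device placements, and call site are illustrative, the exact module paths may differ between releases, and a four-rank process group is assumed to be initialized with this code running on every rank.

    import torch
    import torch.nn as nn
    from torch.distributed._shard import shard_parameter
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # Chunk along dim 0 of the weight, placing one shard on each of four GPUs.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    linear = nn.Linear(32, 32).cuda()
    # Replace linear.weight with a sharded tensor; the layer is then called as
    # usual and the sharded ops gather what they need before the computation.
    shard_parameter(linear, "weight", spec)
    out = linear(torch.randn(8, 32).cuda())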

Future Vision of Sharded Tensor

The goal of sharded tensor is to enable generic intra-layer parallelism as a reusable building block: users should not have to change their model construction code to achieve it, only make a minimal change in the training process. In the long term, we expect sharded tensor to serve as a generic abstraction for implementing parallel algorithms easily.

"WEBVTTKind: captionsLanguage: en- Hello, everyone.This is Yanli Zhao.I'm a software engineerfrom Facebook AI.Today, I'm going to presentfour new APIsthat have been addedto PyTorch distributed package.First, I will introducethe two new dataparallel APIs calledZero Redundancy Optimizerand fully sharded dataparallel API.After that,I will briefly showCUDA RPC primitive supportand its performance improvement.Lastly, I would liketo show one exampleto explain the conceptof sharded tensorand its future vision.First of all, let's see whatPyTorch distributed package is.PyTorch distributed packageis a set of featuresthat facilitate training modelsin distributed systems.Usually, distributed trainingcan be categorizedinto data paralleland model parallel.PyTorch distributed dataparallel usually trains modelin each device independentlyfor different batches of dataand then synchronizethe resultsat the endof each training iterationand the synchronization is builton top of collective calls.Model parallel usually isa process of splitting modelsacross multiple devices andcreating pipeline computationto train models across devices.To utilize this to communicateintermediate resultsbetween machines,and this communicationis usually built outside of RPCor point-to-point calls.The first newly added dataparallel API in PyTorchthis year is calledZero Redundancy Optimizer,also called ZeRO.The original ideais from DeepSpeed ZeRO.In PyTorch even in a distributeddata parallel training,also called DDP training,each process leadsto hold parameters,gradients and optimizer localstates of the replicated model.Thus it consumesquite a lot of memoryand limited the model sizethat can be trainedin a single device.It is observedthat local state sizeof some commonly used optimizerscould be very large.For example, I-LAN optimizerlocal stateconsumes two times memorythan model size.Based on this observation,ZeRO shards optimizerstays among dataparallel processesto reduce memory footprint.The higher memoryefficiency of ZeROallows larger scale of models tobe trained in a single device.The right example showshow a training loop looks likeby combining DDPwith ZeRO optimizer.Just like a normal DDPtraining loop,it first defines a model thenwrap the model using DDP API.Then it defines a ZeRO optimizerthat wraps regularoptimizer class.After that,it's ran slo-mo, forward,backward pass in optimizer step.In the optimizer step,ZeRO will update the shardedparameters and local statesand then forward pass theupdated states to peer processesso that all the processeshave the same stateat the beginningof next iteration.As you can see,ZeRO can help traininglarger model in a single device,but the scale is still limitedbecause each processstill needs to hostthe whole model parametersand gradients.To further scalelarger model training,fully sharded data parallel alsocalled FSDP API is built up.The FSDP API is originally builtby Facebookteam. 
CUDA RPC Support in PyTorch Distributed

The RPC API supports point-to-point communication between a sender and a receiver. The sender issues a remote function call together with its input data; the function and data are serialized into payloads and tensors. Once the receiver gets the message, it deserializes it back into the function call and data, executes the function, and sends the result back to the sender. During message transmission, the TensorPipe backend can choose the optimal channel to achieve the best communication performance. With CUDA RPC, TensorPipe can send tensors directly from local CUDA memory to remote CUDA memory, and one experiment showed a 34x speedup compared to CPU RPC.

To enable CUDA RPC, a device map must be defined. For example, if worker 0, running on cuda:0, needs to make a remote function call on worker 1, which runs on cuda:1, it specifies a device map from device 0 to device 1. With this device map, TensorPipe enables direct CUDA-to-CUDA communication and significantly improves performance. CUDA RPC is an important communication primitive for model parallelism, such as pipeline parallelism, so this performance improvement can greatly speed up model-parallel training.
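
As a sketch of the device-map setup described above, the snippet below configures TensorPipe on a hypothetical worker 0 so that tensors on its cuda:0 are sent directly to cuda:1 on worker 1; the worker names and the remote function are illustrative.

    import torch
    import torch.distributed.rpc as rpc

    def run_worker0():
        options = rpc.TensorPipeRpcBackendOptions()
        # Map this worker's cuda:0 to cuda:1 on "worker1" so tensors move
        # directly between GPU memories instead of staging through the CPU.
        options.set_device_map("worker1", {0: 1})
        rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)

        x = torch.randn(4, 4, device="cuda:0")
        # The call executes torch.add on worker1; inputs and outputs are GPU
        # tensors carried over CUDA RPC according to the device map.
        y = rpc.rpc_sync("worker1", torch.add, args=(x, x))
        rpc.shutdown()

With this mapping in place, GPU tensors passed through rpc_sync or rpc_async travel directly between devices, which is what makes CUDA RPC useful for pipeline-parallel training.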