Lightning Talk - Lessons from Using PyTorch 2.0 Compile in IBM's watsonx.ai Inference - Antoni Martin

**Investing in PyTorch: A Journey of Optimization and Innovation**

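Our goal was to combine the optimizations that PyTorch 2.0 and 2.1 offer (torch.compile, scaled dot product attention (SDPA), tensor parallelism, and CUDA graphs) in a single, open-source Llama 2 inference implementation. The most important prerequisite for getting a speedup from torch.compile is eliminating graph breaks. The open-source Llama implementation computes its rotary embeddings, which run in every layer, with complex numbers, which Inductor does not support, so we reimplemented them with matmuls and other real-valued operations. Selecting an SDPA backend with the usual context manager also caused graph breaks. A compile bug in the new way guards are computed made the model recompile on every call, and FlashAttention forced a recompilation whenever the input or KV cache size changed, even with dynamic shapes (Ed Yang has since fixed that last issue).

For tensor parallelism, DTensor did not yet support torch.compile when we started, so we wrote our own implementation on top of the functional collectives from the new PyTorch Distributed work mentioned in the keynote. This meant fixing several issues in how Inductor compiles these collectives, including memory layouts, race conditions, and asserts that triggered graph breaks. In the end, the whole model compiled without graph breaks.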
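The open-source Llama code computes rotary embeddings with complex numbers, which Inductor did not support at the time, so the IBM implementation rewrote them with real-valued operations. The talk says that rewrite uses matmuls; the minimal sketch below instead uses elementwise cos/sin arithmetic, which is one equivalent real-valued formulation. The function names, the interleaved feature-pair layout, and the base of 10000 are assumptions chosen to match the common complex-number formulation, not the repository's code.

```python
import torch


def rope_angles(seq_len: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: rotation angle per (position, feature pair).

    Returns shape (seq_len, head_dim // 2). base=10000 is an assumed default.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    return torch.outer(positions, inv_freq)


def apply_rope_real(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of x (..., seq_len, head_dim) with real ops only."""
    cos, sin = angles.cos(), angles.sin()
    x_pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # (..., seq_len, head_dim // 2, 2)
    x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
    # (a + ib) * (cos + i sin) = (a cos - b sin) + i (a sin + b cos)
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2).type_as(x)


def apply_rope_complex(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Reference version in the complex-number style that Inductor could not compile."""
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_complex * freqs_cis).flatten(-2).type_as(x)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 4, 16, 64)  # (batch, heads, seq_len, head_dim)
    angles = rope_angles(seq_len=16, head_dim=64)
    # Both formulations should agree up to float32 rounding.
    print(torch.allclose(apply_rope_real(q, angles), apply_rope_complex(q, angles), atol=1e-5))
```

Because `apply_rope_real` only uses multiplications, additions, `stack`, and reshapes, it stays within operations a compiler without complex-number support can handle; the complex version is kept only as a correctness reference.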

With no graph breaks left, the results were strong: on most model sizes, torch.compile with SDPA gives at least a 2x performance improvement, and the 7B and 13B models performed well. When we scaled up to the 70B model with tensor parallelism, however, launching all the collective kernels created noticeable CPU overhead, so we turned to CUDA graphs, the usual recommendation when CPU overhead is too high. Adding CUDA graphs gives the 70B model roughly another 50% on top.
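The talk attributes the 70B overhead to launching collective kernels but does not show its tensor-parallel layers. As a rough illustration of where those collectives come from, here is a single-process simulation of a common Megatron-style split of an MLP block: the first weight is sharded by columns, the second by rows, and the per-shard partial outputs are summed. That sum stands in for the all-reduce a real multi-GPU implementation would issue (for example through PyTorch's functional collectives); the sharding scheme, shapes, and function names are assumptions for illustration only.

```python
import torch


def mlp_reference(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    """Unsharded two-layer MLP: relu(x @ w1) @ w2."""
    return torch.relu(x @ w1) @ w2


def mlp_tensor_parallel(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                        world_size: int) -> torch.Tensor:
    """Hypothetical single-process simulation of a tensor-parallel MLP.

    w1 (d_model, d_ff) is split by columns and w2 (d_ff, d_model) by rows, so each
    simulated rank computes a partial output from its own hidden slice with no
    communication. On real hardware each shard lives on its own GPU and the final
    sum is an all-reduce collective.
    """
    w1_shards = torch.chunk(w1, world_size, dim=1)
    w2_shards = torch.chunk(w2, world_size, dim=0)
    partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
    return torch.stack(partials).sum(dim=0)  # stand-in for all_reduce(sum)


if __name__ == "__main__":
    torch.manual_seed(0)
    x, w1, w2 = torch.randn(3, 8), torch.randn(8, 32), torch.randn(32, 8)
    for ws in (1, 2, 4):
        print(ws, torch.allclose(mlp_tensor_parallel(x, w1, w2, ws), mlp_reference(x, w1, w2), atol=1e-4))
```

In this scheme every sharded block ends with one collective, so a deep model issues many collective launches per generated token, which is the kind of per-kernel CPU work that CUDA graphs are meant to amortize.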

However, when we first enabled CUDA graphs, we hit segmentation faults and out-of-memory errors. One fix we found was to enforce that the KV cache, which is an input to our model, is contiguous, but some issues remain that we are still working out.
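The talk reports that enforcing a contiguous KV cache was one fix for the CUDA-graph crashes, without explaining the mechanism. The minimal sketch below only shows how to check and enforce contiguity, assuming a preallocated cache of shape (batch, heads, max_seq_len, head_dim) that is sliced to the current length; the layout and function names are hypothetical. Slicing along the sequence dimension yields a strided view into the large buffer, and `.contiguous()` copies it into a dense tensor.

```python
import torch


def current_kv_view(cache: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: slice a (batch, heads, max_seq_len, head_dim) cache to seq_len."""
    return cache[:, :, :seq_len, :]


def as_model_input(kv: torch.Tensor) -> torch.Tensor:
    """Enforce a contiguous layout before handing the cache to the model.

    Tensor.contiguous() returns the tensor itself (no copy) when it is already
    contiguous, and a dense copy otherwise.
    """
    return kv.contiguous()


if __name__ == "__main__":
    cache = torch.zeros(2, 4, 128, 64)
    view = current_kv_view(cache, seq_len=10)
    print(view.is_contiguous(), view.stride())   # strided view into the big buffer
    fixed = as_model_input(view)
    print(fixed.is_contiguous(), fixed.stride())
```

The copy costs memory bandwidth on every call, so whether to keep the cache contiguous by construction or to copy at the model boundary is a design choice this sketch does not settle.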

**Future Plans and Goals**

As we move forward, we have several key goals in mind. First, DTensor support for torch.compile has just landed in the PyTorch nightlies, so we will try it next. Second, our model currently compiles without graph breaks only under `torch.no_grad()`, which limits it to inference. We are looking for a fix, and an issue on GitHub is tracking the progress.
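One way to confirm that a model compiles without graph breaks is `torch.compile(..., fullgraph=True)`, which raises an error instead of silently splitting the graph. The sketch below assumes PyTorch 2.x and uses the `"eager"` backend, which captures the graph with Dynamo but skips Inductor code generation, so it is cheap to try on CPU. The tiny module is a hypothetical stand-in for a transformer layer, and inference is wrapped in `torch.no_grad()` as described above; this does not reproduce the training-mode graph breaks, it only shows the check.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a transformer layer."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(self.norm(x))


def compiled_inference(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # fullgraph=True turns any graph break into an exception.
    # backend="eager" only traces with Dynamo; "inductor" is what gives real speedups.
    compiled = torch.compile(model, backend="eager", fullgraph=True)
    with torch.no_grad():
        return compiled(x)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyBlock().eval()
    x = torch.randn(4, 16)
    with torch.no_grad():
        expected = model(x)
    print(torch.allclose(compiled_inference(model, x), expected))
```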

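Another area of focus is using our Llama implementation for training, not just inference, and the `torch.no_grad()` limitation is one of the main blockers. CUDA graph memory is a further concern. As Elias showed in his talk earlier, every captured kernel invocation takes about 65 kilobytes of memory. The Llama 2 70B model has 80 layers with six to seven kernel invocations each, so invoking it with a thousand, or even 4,000, different shapes (depending on context length) fills GPU memory quickly. We need solutions for this, as well as for the segmentation faults we have been hitting.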
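The figures quoted in the talk (about 65 KB per kernel invocation, 80 layers, six to seven kernels per layer) allow a back-of-the-envelope estimate. The sketch assumes, as a simplification, that each distinct input shape records one graph covering all 80 layers and that memory grows linearly with the number of shapes; it ignores tensor parallelism, memory pools, and any cost beyond the per-kernel figure, and it reads "kilobytes" as decimal units.

```python
KB = 1000  # assumption: decimal kilobytes; the talk only says "65 kilobytes"


def cuda_graph_overhead_bytes(num_shapes: int, layers: int = 80,
                              kernels_per_layer: int = 7,
                              bytes_per_kernel: int = 65 * KB) -> int:
    """Rough memory cost if every distinct shape records its own graph."""
    return num_shapes * layers * kernels_per_layer * bytes_per_kernel


if __name__ == "__main__":
    for shapes in (1, 1_000, 4_000):
        for kernels in (6, 7):
            gb = cuda_graph_overhead_bytes(shapes, kernels_per_layer=kernels) / 1e9
            print(f"{shapes:>5} shapes x {kernels} kernels/layer: {gb:8.3f} GB")
```

Under these assumptions, one graph costs about 31 to 36 MB, a thousand shapes about 31 to 36 GB, and 4,000 shapes about 125 to 146 GB, which is consistent with the talk's point that memory fills up quickly.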

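We also want to keep adding optimizations to the model: DTensor support; the new DTensor optimizations Will Constable described for the matmul and all-gather pattern, which our code definitely has; probably the new Flash-Decoding technique, since FlashAttention-2 currently brings little or even negative benefit for inference because it does not parallelize the right way for decoding; and quantization, paged attention, and the other inference optimizations people have been discussing in recent weeks. Throughout, we will make sure everything stays compatible with what we already have.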

**Conclusion and Future Directions**

In conclusion, we've made solid progress optimizing our Llama 2 implementation with PyTorch. Combining torch.compile, SDPA, tensor parallelism built on functional collectives, and CUDA graphs, together with fixes to Inductor, has given us strong performance. Much work remains, particularly in making CUDA graphs reliable and memory-efficient for the 70B tensor-parallel model.

We have decided to invest in PyTorch and to contribute our optimizations back to core PyTorch wherever we can, which is why we open-sourced our Llama implementation. This work also shows that if you design your models carefully, the gains from different optimizations add up, although it takes effort to make everything work together. This is the composability theme in a nutshell.

**Acknowledgments**

We'd like to thank everyone at Meta and IBM who made this work possible. The bug fixes and everything presented here depended on them opening PRs, writing code, helping us debug, and making sure all the pieces fit together.

We're excited to continue working on these issues and exploring new optimization techniques. If you have any questions or would like to discuss further, we're happy to chat.

**Transcript**

I'm from IBM Research, and my talk today is about lessons from using PyTorch compile for inference, specifically in our Llama 2 model implementation. Why Llama 2? First of all, it is already being used by our clients, and it's important for us that our clients can run it efficiently. More generally, it's important that we can make large models like Llama 7B, or even bigger models, run efficiently, and for this purpose we decided to collaborate with PyTorch and build an open-source, efficient implementation of Llama.

The issue we found when we tried the open-source Llama implementation available on GitHub is that it was hard to combine all the new optimizations that PyTorch 2.0 and 2.1 have been offering into a single model for inference. So we combined the basic torch.compile, SDPA, tensor parallelism, and CUDA graphs. As I mentioned, this is open source: the QR code or the link on the slide will take you to our GitHub repo, which we open-sourced last week with our Llama 2 implementation.

First, the performance numbers. The green and the red are the compile plus SDPA numbers, and for most of the models, compile gets you at least a 2x improvement in performance. For 70B, which is the model we run with tensor parallelism on eight GPUs, adding CUDA graphs gets you roughly an extra 50% on top. In 7B, you can see something like 40 milliseconds per token going down to 30 to 40 milliseconds per token, depending on your batch size and sequence length. You will notice that the advantage from SDPA, in particular FlashAttention-2, is not that big, and in some cases you actually get worse performance. If anyone here has read the Flash-Decoding blog post published a few days ago, it explains exactly why: FlashAttention does not parallelize in the right way for inference. One of our next steps will probably be working with them to get this new Flash-Decoding technique implemented in our model, to get the extra performance from SDPA that we should be getting.

So what did we do to make all this work? First of all, as people mentioned in the keynote, the most important thing for getting compile to make your models faster is making sure you have no graph breaks. The Llama implementation on GitHub had one main issue: the rotary embeddings, which run in every single layer, use complex numbers, and complex numbers are not currently supported in Inductor. That meant we had to reimplement the rotary embeddings using matmuls and other operations that are supported in ATen and in Inductor, to make sure our code had no graph breaks.

Another important thing for us was being able to choose which SDPA implementation to use: FlashAttention-2, memory-efficient, or the unfused implementation, so that we can do ablation studies like the one just shown. We learned that the usual context manager, something like `torch.backends.cuda.sdp_kernel(enable_flash=...)`, does not work without graph breaks either, so we also had to find a way to do that without graph breaks.

Finally, when we got everything running, we realized there was a bug in compile: a new way of computing the guards that make sure the modules in the code do not change between iterations was causing everything to recompile on every call. So instead of one compilation, or three if you're compiling with dynamic shapes, we would get 500 compilations, or N compilations if you run N times. For SDPA, once we got FlashAttention-2 running, we found another issue: every time the size of your inputs or your KV cache changed, you would recompile, even though we were compiling with dynamic shapes. We found that Flash specifically, and not any other SDPA kernel, was forcing recompilation. Ed Yang fixed that issue for us about a week ago, and it's been working great since then.

For tensor parallelism, we first tried DTensor, but the DTensor implementation in PyTorch did not support torch.compile. I think that's no longer true as of the last three or four days; I think it just got merged. But while we were preparing this presentation, we had to implement our own tensor parallelism that was compatible with torch.compile, because, again, what I want to do here is combine all these optimizations in one. So we used the functional collectives that people have been mentioning, the new PyTorch Distributed work from the keynote, and we had to fix a few issues in how Inductor actually compiles these things: there were some issues with memory layouts, race conditions, and asserts triggering graph breaks. In the end, we got everything working without graph breaks.

We got all the numbers for 7B and 13B, everything was working great, great results, and then we said, OK, let's try 70B with tensor parallelism. There was clearly some CPU overhead from launching all these collective kernels that was causing issues for us, so we tried CUDA graphs, which is the solution everyone recommends when you have too much CPU overhead. When we tried, we started hitting segmentation faults and out-of-memory errors. One solution we found is that the KV cache, which is an input to our model, needs to be enforced to be contiguous, but there are still some issues we will be working out, and I'll talk about them in the next slide.

Next steps. As I just mentioned, DTensor compile support just landed in PyTorch nightly, so that's one thing we will be trying next. Another thing is that currently all this work only works for inference, because the code only compiles without graph breaks under `torch.no_grad`. We're looking for a fix, and there is an issue on GitHub tracking that progress. At some point we want to use our Llama implementation for training, not just inference, and that is one of the main blockers.

Elias, in his talk about three minutes ago, showed a graph where every kernel invocation takes 65 kilobytes. The Llama 70B model has 80 layers, and each layer has about six to seven kernel invocations. So if you invoke your model with different sizes and shapes, say a thousand or even 4,000 different shapes depending on your context length, your GPU memory fills up pretty fast. We need to look at solutions for this, and also for the segmentation faults we have been finding.

Finally, we want to keep adding optimizations to this model, some of which were already mentioned in presentations today. We want to add DTensor support and try the new DTensor optimizations that Will Constable was talking about for the matmul and the all-gather pattern, which we definitely have in this code, so we want to take advantage of that. We want to try quantization, of course, paged attention, and all the new inference optimizations people have been talking about in the last few weeks, always making sure that everything is compatible with what we already have.

So, our conclusion is that we have decided to invest in PyTorch and make sure that any optimizations we make are, as far as we can, contributed back to core PyTorch, and that's why we open-sourced our Llama implementation. It also showed that if you are careful in designing your models, you can add up the performance from different optimizations, although it takes some work to make sure everything works together, which is this whole composability theme.

I would also like to thank all these people from Meta and IBM. I'm the one presenting, but all these bug fixes and all this work would not have been possible without everyone here: doing PRs, writing code, ensuring that everything works together, helping us debug, and making sure everything was ready for me to show here today.
And with this, I'm pretty much done. If you have any questions, feel free to ask, although I know it's almost lunchtime. Thank you.

Yes? So, right, I don't know how much I can share about IBM-specific hardware. My bosses are here, so maybe they can say yes or no. OK, we can take it offline, I guess, but I can explain later if you want.