NeurIPS 2020 - JAX Ecosystem Meetup

The Excitement Around JAX: Growth and Development

We're personally very invested. The JAX team has grown, and JAX is working well in places like DeepMind, Google Brain and across research more broadly. I think the best days of JAX are ahead; I guess that's what I'm saying. JAX is growing, the folks working on it are very passionate, as you've seen, so I wouldn't worry at all about JAX going away or anything like that.

Awesome. We are at time, but we might have time for one more question, and it looks like Robin had his hand raised.

Oh yeah, this looks amazing. I was wondering: what do you think are the stumbling blocks that people new to JAX often hit? Are there certain things that are often confusing, or that cause issues when people are just getting started?

We have a great link for that: there's a Colab in the docs called "JAX - The Sharp Bits", and I'll link it in the chat.

In addition, I'll add one from the DeepMind side. One of the obvious strengths of JAX is its bet on functional programming: you're raising everything into a functional paradigm where things are stateless and you don't need to worry about side effects. We decided to diverge from that slightly with some of our libraries, such as Haiku, because we had other considerations to balance. We had so many users already using Sonnet as their way of defining and reasoning about neural networks in TensorFlow, and we really wanted to maintain that API. So there's certainly some cognitive overhead in writing something that looks object-oriented and is converted into functions for you, and it takes some users some time to ramp up, but that was the trade-off we decided worked for us as an organization. There are certainly other JAX neural network libraries that take different approaches, which other people find work better for them, and in the spirit of incremental buy-in and cross-compatibility, you're free to use any part of any of those with any of our libraries as well. For the most part, I think that's been working for people.
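To make the object-oriented versus functional trade-off concrete, here is a minimal sketch in plain JAX, not Haiku, of the kind of pure functions a module definition ends up as: an init function that creates parameters and an apply function that uses them. The single dense layer and the names init, apply and loss are illustrative assumptions. Haiku's hk.transform produces an init/apply pair along these lines, though with its own signatures (its apply also takes an RNG argument).

```python
import jax
import jax.numpy as jnp


def init(rng, x, out_dim=2):
    """Create parameters for a single dense layer (illustrative, not Haiku's API)."""
    in_dim = x.shape[-1]
    w = jax.random.normal(rng, (in_dim, out_dim)) / jnp.sqrt(in_dim)
    b = jnp.zeros((out_dim,))
    return {"w": w, "b": b}  # parameters are just a dict (a pytree) of arrays


def apply(params, x):
    """Pure function: the output depends only on its arguments, with no hidden state."""
    return x @ params["w"] + params["b"]


def loss(params, x):
    return jnp.mean(apply(params, x) ** 2)


x = jnp.ones((5, 3))
params = init(jax.random.PRNGKey(0), x)
y = apply(params, x)
# Parameters are explicit inputs, so gradients come back with the same tree structure.
grads = jax.grad(loss)(params, x)
```

Because params is an ordinary pytree of arrays, it can be handed to any optimizer that works on pytrees, which is the kind of cross-compatibility described above.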

Thank you. I would add a common one that was also mentioned in the chat: pseudo-random number generation. That might be a bit surprising to a newcomer to JAX who is used to NumPy, where you barely have to reason about the state of the random number generator at all. I would say, though, that while it can be surprising at first, it pays off in the long run. Once you get used to this way of dealing with random number generation, at least personally, I ended up concluding that this was the right way all along. There is also a very nice document in the JAX documentation about the PRNG design, explaining the reasons and motivations for the way JAX deals with random numbers. I think it's a great read: if, as a newcomer, you're a bit puzzled by random numbers at the beginning, just read that doc, because I think it will clarify a lot.

Yeah, I wholeheartedly agree with what Matteo is saying. Once you've used JAX's RNGs, all other RNGs, which silently maintain their own state, suddenly become very scary. So, definitely.
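The difference from NumPy fits in a few lines. This is a minimal sketch using jax.random: the random state lives in an explicit key that you pass around and split, rather than in a hidden global generator as with numpy.random. Reusing a key reproduces the same numbers, so fresh randomness requires splitting.

```python
import jax

key = jax.random.PRNGKey(0)  # explicit, user-managed random state

# Same key in, same numbers out: no hidden global state is being advanced.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To get new randomness, split the key and use each subkey only once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

# Split into many keys at once, e.g. one per example in a batch.
batch_keys = jax.random.split(key, 4)
noise = jax.vmap(lambda k: jax.random.normal(k, (3,)))(batch_keys)
```

Because every random draw is a pure function of its key, the same code behaves identically under jit, vmap and pmap, which is part of the payoff described above.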

So we are at time. Thank you all so much for your questions. There might be one more slide left; is this correct, Matteo? Haha, there we go. Cool. So thank you, everybody, for coming. Please make sure to share your JAX projects on social media, whatever your favorite flavor of social media is. GitHub is also social, and it's my favorite social hangout place. Anything that you create, please tag it with JAX ecosystem so that we can make sure to see it and share it. And if you're interested in contributing, there are many issues on the JAX repo that are good for first-time contributors, and many places to get involved. We look forward to seeing what you create. Thank you so much, everyone, for attending, and thank you to the JAX core team and all of the JAX ecosystem teams for presenting your work and for getting us all very excited. I know I'm significantly more excited than I was even an hour ago. So thank you, everybody, have a great day, have a great NeurIPS, and see you soon.

...it off to Matteo. Matteo, do you want to take it over?

Yeah, thank you very much, Paige. Today I'll be presenting with my friends and colleagues David, Mihaela, Jun, Fabio and Theo. Our objective is to give you an overview of the ecosystem of tools and libraries that we've been building around JAX, to give you a feel for why we, as a group and as a company, are quite excited about JAX, and hopefully to answer some of your questions: what your doubts and concerns are, and whether JAX could also be useful for your own research and projects.

Let me start with a very brief, bird's-eye introduction to JAX and why it might be an exciting framework for building machine learning projects and research. Many of you will already be familiar with it, but let me repeat the fundamentals. At its core, JAX is a Python library designed for high-performance numerical computing, and it supports three main areas. The first is differentiation: forward- and reverse-mode automatic differentiation of arbitrary numerical functions. Here there is a whole family of JAX primitives, like grad, hessian, jacfwd and jacrev, that let you exploit JAX's automatic differentiation capabilities. The second broad family of primitives is related to automatic vectorization. In a lot of machine learning we rely on a single-instruction-multiple-data style of programming, where we want to apply the same transformation to a whole batch of data; for instance, we might be computing a loss over a big batch of input and output samples. JAX makes life a lot easier for researchers and practitioners of machine learning by
exposing some very neat and simple abstractions for single-instruction-multiple-data programming, mostly vmap and pmap, and we're going to hear some examples today of how these can really make a difference in writing clean but expressive code. Finally, the third big area inside JAX is just-in-time (JIT) compilation. JAX is built on top of XLA and uses it to just-in-time compile your code, which makes your code faster if you're running on a CPU, and also gives you transparent GPU or Cloud TPU acceleration.

The important feature of JAX, if you take away just one message about what JAX is, is that all these functions and abstractions are implemented as composable program transformations. What does this actually mean? Let's take a very concrete, simple example: a numerical function that squares its first input and adds the second. You could write it in Python like this, and x and y could for instance be NumPy arrays. The function is of course evaluated like any Python function: you call it with some inputs, and the first input is squared and added to the second. What would computing a gradient look like with JAX? The neat thing about JAX is that the gradient of this function is also a function. You call jax.grad on a function and you get back an ordinary Python function that, when called, computes the derivative of the original function, in a very neat and transparent way. Importantly, this transformation is composable, so you could for instance compute the second-order gradient by calling grad of grad of the original function; this again gives you a Python function that, if you pass x and y, returns the second-order gradient at those values of x and y. Similarly, you can mix and match
this with other program transformations like compilation. If you call jit of grad of grad of the original function, you get a compiled second-order gradient function that you can call like any other Python function. The first time you call it, it traces the code and compiles it with XLA; subsequent calls with inputs of the same shapes and types execute the precompiled XLA code and are therefore much faster. And of course it doesn't stop here: all the JAX primitives are neatly composable. You can write your function as if you were dealing with single examples, and if you vmap it, you get a function that expects a batch of inputs and computes the same function on the entire batch. Finally, if you want to execute this batched, compiled, second-order gradient calculation in parallel on multiple accelerators, for instance multiple GPUs, you just chain one more transformation: you pmap a vmapped jit of grad of grad of the function, and again this is just a Python function that you can call on your inputs. I hope this gives you a flavor of how composable program transformations let JAX expose a very thin set of primitives that are easy to understand, but that you can combine into very rich programs supporting quite powerful use cases.

Before I leave it to my colleagues to go into more detail, I want to conclude with a couple of remarks on why JAX is so convenient. One is of course the composable nature of all its abstractions. But it's also good to remember that all of the numerical functions are written with a syntax that is fully consistent with NumPy, so you're literally just writing NumPy code, and you can then transform it using these program
transformations. This means JAX can feel familiar even when you're just starting, because Python and NumPy are so widely used. The other point I want to highlight is that JAX is not a vertically integrated framework like many other ML frameworks. It's really focused on getting this core bit of numerical computing right, and there is a very rich ecosystem and community built around it that can make other things, like building neural networks, easy for you. This is what David is going to talk about now.

Hi, yeah, thanks a lot, Matteo, that's a great overview. My name is David, I'm a research engineer at DeepMind, and I'm here today to talk a little bit about our DeepMind JAX ecosystem. This is essentially a collection of libraries that we've been developing over the past 18 months or so, which build on top of JAX, as Matteo has just been explaining, and add functionality for specific machine learning research. Next slide, please.

So why an ecosystem? DeepMind researchers have had a lot of great initial success using JAX. Since this is a NeurIPS-hosted event: many of the papers we're presenting at NeurIPS this year use JAX under the hood to produce their results. As engineers and researchers at DeepMind, we're constantly asking how we can continue to support and accelerate this great work and the early success we've seen. A few considerations go into this. The first is what Matteo was saying: JAX is not a vertically integrated machine learning framework. This is a strength: it does a core set of things very, very well, and what we want to do is build our own libraries on top of it to meet our specific machine learning research needs.
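The composable transformations Matteo walked through can be written out directly. This is a minimal sketch of his example function, assuming scalar float inputs for the gradient (jax.grad needs a scalar output) and using vmap to add the batch dimension; the pmap step maps over however many local devices are available, which is a single device on a plain CPU machine.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # The example from the talk: square the first input and add the second.
    return x ** 2 + y


df = jax.grad(f)                      # df/dx = 2x (grad differentiates w.r.t. the first argument)
d2f = jax.grad(jax.grad(f))           # second-order gradient: d2f/dx2 = 2
fast_d2f = jax.jit(d2f)               # compiled with XLA when first called
batched_d2f = jax.vmap(fast_d2f)      # written for one example, now maps over a batch axis
parallel_d2f = jax.pmap(batched_d2f)  # maps over a leading device axis

n_devices = jax.local_device_count()
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
ys = jnp.ones_like(xs)
out = parallel_d2f(xs, ys)  # shape (n_devices, 4); every entry is the constant 2.0
```

Each wrapper returns an ordinary Python function, so the order of composition is just the order of the calls; here the leading axis of xs is split across devices and the second axis is the batch.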
Whatever we build needs to support our rapidly evolving requirements, for instance a specific focus on reinforcement learning, among other things. And where possible, we want to strive for consistency and compatibility with frameworks and tools we've released in the past, for instance our TensorFlow libraries such as Sonnet and TRFL. The solution we've arrived at is an ecosystem: a collection of libraries of reusable and unopinionated JAX components. Each library is intended to do one specific thing well and, importantly, supports incremental buy-in. For instance, we're going to show you libraries for neural networks, for optimization and for reinforcement learning, and from day one we've made sure to develop these so that you can pick and choose the pieces you want without being locked into decisions about other libraries, or about other tools within the same library. And of course, where possible, we open-source everything to enable research sharing and reproducibility of the results we build on top of the JAX ecosystem. Thanks.

I'm quickly going to go over three examples of libraries we've built, although there are others, and I encourage you all to check them out. For more details, we recently posted a blog on the deepmind.com website that goes through this in more depth.

Haiku is our JAX library for neural networks. As Matteo was highlighting, much of the strength of JAX comes from its functional programming paradigm, where we have stateless functions and can compose function transformations: jax.grad or vmap of a function returns a function, and that function can be applied statelessly to a NumPy or jax.numpy array. But some abstractions that are common in machine learning, like neural network trainable parameters and state, often fit the object-oriented paradigm better, or at least that's how people are used to thinking about them, so
Haiku tries to bridge the gap between the object-oriented world and the functional world of JAX. It essentially provides tools by which you can take stateful objects such as neural networks and convert them into initialization and application functions that can then be used directly with JAX. Our researchers have had great success trivially porting previous TensorFlow results, such as the ones on this slide, into JAX, and it's a widely used and mature framework, both within DeepMind and more widely within Google and the public community. Next slide.

Optax is for optimization. Much like neural networks, optimizers are often thought of in an object-oriented way, as having state. Optax is essentially a gradient processing and optimization library. At its core it provides a zoo of common optimizers that people know and love, things like SGD, Adam and momentum, and importantly these are defined as chains of simple Optax primitives: fundamental gradient processing operations that are common to many optimizers and to many other things as well. So Optax provides simple out-of-the-box implementations that chain these components into optimizers, and gives users the tools to do this themselves. In addition, it provides a lot of useful utilities for trivially applying gradient-based updates to neural network parameters, and not just Haiku parameters from our JAX neural network library: true to the spirit of incremental buy-in and cross-compatibility, most of the popular JAX neural network libraries you may be familiar with should be compatible with Optax, since essentially anything that represents its parameters as a tree of NumPy arrays works. Again, it's widely used within DeepMind, with a growing user base outside. Next slide, please. So, as you probably know, DeepMind
cares quite a lot about reinforcement learning. One of the complexities of reinforcement learning is that it's very easy to get things wrong: for instance, if you place a stop gradient in the wrong place, which is very easy to do, the entire experiment falls apart. The idea of RLax is to provide a library of trusted, battle-tested reinforcement learning components; not even algorithms, but subcomponents of algorithms. For instance, on the right-hand side you can see how a Q-learning primitive and a squared-loss primitive can be combined to define a simple loss function you might want to use. This provides a substrate through which our researchers can use components they trust and build algorithms they trust, while quickly contributing their own ideas back to be shared with the broader community. A lot of the emphasis in this library has been on readability. What does this mean? We use Unicode everywhere, so if you look at our docstrings, you'll see we're not scared to make them look like textbook examples. We're also not afraid to repeat ourselves: when people land on a function for the TD error, we want them to see where the TD error is computed, rather than having to chase it down through three or four different functions. Next slide.

I don't have time today to go over all the libraries we've released as part of our ecosystem, but there are others, for instance the Jraph library for graph neural networks, which we released last week, and our Chex library for testing and reliability. Here are some links; I encourage you to check them out. Beyond the ecosystem view we're describing today, there's an additional shell around the outside: frameworks that are built on top of these libraries. Some examples from DeepMind are Jaxline and Acme, for supervised learning and RL respectively. I encourage you to check
these out. And in addition to the work from DeepMind we've been talking about today, there's a lot of really great work being built on top of the JAX ecosystem, both within Google and in the public more generally. For the rest of this presentation, I'm going to hand over to people who will talk about research results based on JAX and the JAX ecosystem, starting with Mihaela, who's going to talk about GANs and generative models.

Thanks, Dave. I hope everyone can hear me. I'm super excited to talk today about what I think makes JAX amazing for generative models, and specifically GANs. Next slide, please.

The reason I chose GANs is that they're a bit different from the standard view we might have of other generative models, the paradigm of one model, one loss. In the GAN case we have two models, two players: the generator, which produces data, and the discriminator, which learns to distinguish between this generated data and the real data. The goal of the discriminator is to be really good at distinguishing between the two, and the goal of the generator is to fool the discriminator into thinking that the data it generates is real. This can be set up either as a zero-sum game or with other types of losses; the two losses don't necessarily need to be connected, as long as the underlying principle holds: one player learns to distinguish and one player learns to fool, hence the name adversarial. Next slide, please.

Now, if we want to implement this in JAX, one thing we notice straight off the bat is that when you implement a GAN, you might want to do multiple discriminator updates for each generator update, or vice versa, though more often it's the discriminator. This is super trivial in JAX: you just write a Python for loop for each of the players, and we all know how to do that. And inside this Python for loop, for each of the
players, you update the parameters of that player. The really neat thing here, which Matteo also highlighted earlier, is that we get gradient functions, not gradient values. We can ask JAX for the gradient function of the discriminator loss with respect to the discriminator parameters, and then evaluate that function at the current values of both the discriminator and the generator parameters, with whatever data batch we have. Crucially, there's no need for any stop gradient on the generator parameters, even though the value of that loss depends on them, because we only ask JAX for the gradient with respect to the first argument, which is the discriminator parameters. So this is very easy to reason about, and it's almost like pseudocode: if you asked me to write GAN pseudocode, it would look very similar to this. Once we have the gradients, we can pass them to an optimizer, like the Optax optimizers David talked about, get an update, and get our new discriminator parameters. Then we do exactly the same thing for the generator, only this time with a different loss, the generator loss, and a new gradient function: the gradient of the generator loss with respect to the generator parameters. So this is all very neat. Next slide, please.

One thing I want to highlight here is that we now have direct access to gradients. For example, if I want to see what the gradient norm looks like at different layers, I can easily do that: I just add one or two lines of familiar NumPy code and I get those statistics. I think this is very useful for helping us as researchers build intuition about what our models are doing, in a very simple way. Next slide, please.

Something else that I think is maybe specific to GANs, and that
JAX highlights very nicely, is that it gives you more control and makes you think about the decisions you're making. One such decision: when I update the discriminator, what should I do about the generator state? By state I don't mean the generator parameters but, for example, the generator's batch norm statistics: do I or do I not want to update them? Let's leave aside whether you actually want to do this in GANs; the point is that when you implement GANs in JAX, it forces you to think about this and to make the decision explicitly. Looking briefly at this quick implementation of the discriminator loss that I made up, we see a forward pass through the generator, which returns a new generator state. When I return from the discriminator loss, I can say that I want to use this new generator state together with the new discriminator state, or (next slide, please) I can just decide to ignore it and return the generator state that was given to me. Very, very easy. Again, I have to think about what I want, and getting to what I want is just a matter of changing a few lines of code here, instead of digging somewhere deep for something I might not have access to. Next slide, please.

Last but not least, I want to say one more time that this functional approach makes code very close to math. Especially in generative modeling, where we often deal with many distributions and with gradients of quantities that depend on distributions, this makes code much easier to reason about. I'm going to highlight two things Matteo talked about earlier: vectorizing functions and computing Jacobians. Say I want to compute the gradient, with respect to some parameters, of an expectation under a distribution that depends on those parameters. This is not trivial. We can use the score function estimator, but
that has higher variance, and for some distributions we can use the reparameterization trick instead. This trick lets us rewrite the expectation with respect to another distribution that doesn't depend on the parameters, and push the parameters inside. Even in this new form we still have an expectation, so we have to compute some function over multiple samples. We could expect users to always pass us vectorized functions, which can be okay if it's only neural networks, but sometimes it's something else: we don't only compute gradients of neural network functions. Or we can do it for them, and JAX really allows us to: we just say, I'm going to vectorize the function for you, so you can pass in anything you want. I find this particularly useful for tests, because even though we might use neural networks in our experiments, we often want to test with a complicated nonlinear function, and we don't want to spend time vectorizing all of the functions in our tests. And if I want a Jacobian rather than a gradient, it's just a matter of changing a few characters. So I think this is really convenient and very easy to reason about. That's it for me, and I'm going to hand it over to Jun, who's going to talk about meta-gradients.

Thanks. Hi everyone, I'm Jun, a research scientist at DeepMind, and I'm going to briefly talk about how I used JAX for my recent work. Next slide, please.

In this work on discovering reinforcement learning algorithms, we try to meta-learn a reinforcement learning update rule from a distribution of agents and environments, as in this figure. There were several technical challenges because of the unique problem setup. Next slide. First, we wanted to simulate many independent learning agents, where each agent is
interacting with its own set of environments. This is already quite unusual, because in reinforcement learning we normally consider just one learning agent, but here we wanted to simulate multiple agent lifetimes simultaneously. Next, please. At the same time, we wanted to apply the same update rule, which is the meta-learner, to all learning agents in a completely synchronized way. Next, please. We also wanted to calculate the meta-gradient through this agent update procedure, which requires second-order gradients. And finally, we wanted to massively scale up this approach by increasing the number of learning agents without introducing much extra computational cost. These were quite challenging, and to address them we used JAX, which allowed us to handle all of them easily.

In the next slide, I'm going to briefly describe how we implemented this system in JAX. First, we implemented every environment in JAX, so that we could apply vmap and pmap later. As you can see in this figure, we started by implementing a single update rule, a single agent and a single JAX environment interaction, which was quite simple. If you go to the next slide: we then added a vmap to implement multiple JAX environments for one learning agent. In the next slide, we added another vmap at the outer scope to implement multiple learning agents, as in this figure, where each agent still interacts with its own set of environments. Finally, we added pmap to run multiple copies of the same computation graph, distributed across multiple TPU cores, as in this figure, with a shared update rule. Each TPU core essentially has its own set of agents and its own set of environments, but they all share the same update rule, all calculate the meta-gradient in a completely synchronized way, and then perform the meta-update. If you go to the next slide, this is the pseudocode of
our algorithm. The top part implements the agent update, the bottom part implements the meta-gradient calculation, and the yellow parts are the vmaps and pmap I mentioned in the previous slides. If we remove these vmaps and pmaps, it essentially becomes the single-agent, single-environment interaction, but by adding just a few vmaps and pmaps, as in this figure, we can easily convert the simplest single-interaction implementation into the massively parallel system I showed in the previous slide. In the actual experiment we used a 16-core TPU, and with it we could simulate 1,000 parallel learning agents and 60,000 parallel environments with one globally shared update rule, processing more than 3 million steps per second. Next slide, please.

To summarize: we had several interesting challenges because of this unusual problem setup, but by using JAX and TPUs we were able to handle all of them quite easily, without much engineering effort, which was quite nice for us. We also got quite interesting results out of this project, so if you're interested, please check out the paper and come by the poster session tomorrow. That's all I have, thanks. I'm going to hand over to Fabio and Theo.

Hi everyone, I'm Theo, a research scientist at DeepMind, and I'm going to present some work Fabio and I are doing on search and model-based RL. Next.

So here we're shifting gears a little. So far we have looked at applications of JAX that crucially leverage its gradient computation capabilities, and you may ask: is gradient-based computation the main use of JAX, the main thing we care about? Here we showcase another application where JAX enables fast research iteration, one that uses gradients but where they're not the core of the compute: Monte Carlo tree search in a model-based RL setting, as seen in AlphaZero and MuZero. We're training RL agents that plan effectively using an internal model of the world: they form a plan, which is a sequence of actions that they optimize over, and then use that plan both to act in the real world and to update the policy priors they use to help guide planning.

There are challenges in doing so. The first is that typical neural-network-guided search tightly integrates the control logic with the network machinery, and this is quite tricky to debug. There are challenges around scalability and parallelism, because search is inherently a sequential algorithm; Fabio will say more on this later. And, as is typical for model-based RL setups, there are a lot of issues around the data pipelines, or rather design choices: the use of replay, the share of synthetic versus real data, how to use data for policy learning versus model learning, and so on. So you need a framework that lets you quickly test lots of different ideas in that kind of setup. Next slide. Next again, thank you. Sorry, just one slide back.

In the last few years we've seen an explosion of work at the intersection of search-based algorithms, which evaluate different candidate solutions using more discrete reasoning, and neural networks. This has been applied to games such as Go, Atari and puzzle games like Sokoban, as well as to robotics, chemical design and so on. Next slide.

The particular algorithm we investigate and replicate is MuZero, an extension of AlphaZero that learns a model of the environment. It's a paper from last year that obtains state-of-the-art results on a variety of games such as chess, shogi, Go and Atari, and again, it relies on a learned neural network model of the environment.
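As a rough illustration of planning with a learned model, here is a toy sketch, not MuZero itself: a hypothetical dynamics function predicts the next latent state and reward for an action, jax.lax.scan unrolls it over a candidate action sequence, and jax.vmap scores many candidate plans at once. The linear-tanh "networks", the sizes and the names init_model, dynamics and evaluate_plan are all assumptions made for the example; MuZero searches over its learned model with MCTS rather than scoring fixed action sequences.

```python
import jax
import jax.numpy as jnp

STATE_DIM, NUM_ACTIONS = 4, 3  # toy sizes, chosen for illustration


def init_model(rng):
    # Hypothetical toy "learned model": random linear maps stand in for trained networks.
    k_dyn, k_rew, k_val = jax.random.split(rng, 3)
    return {
        "dyn": 0.1 * jax.random.normal(k_dyn, (STATE_DIM + NUM_ACTIONS, STATE_DIM)),
        "rew": 0.1 * jax.random.normal(k_rew, (STATE_DIM,)),
        "val": 0.1 * jax.random.normal(k_val, (STATE_DIM,)),
    }


def dynamics(params, state, action):
    # Predict the next latent state and the reward for taking `action` in `state`.
    inp = jnp.concatenate([state, jax.nn.one_hot(action, NUM_ACTIONS)])
    next_state = jnp.tanh(inp @ params["dyn"])
    return next_state, next_state @ params["rew"]


def evaluate_plan(params, root_state, actions, discount=0.99):
    # Discounted predicted return of an action sequence, bootstrapped with a value estimate.
    def step(carry, action):
        state, ret, scale = carry
        state, reward = dynamics(params, state, action)
        return (state, ret + scale * reward, scale * discount), None

    init = (root_state, jnp.zeros(()), jnp.ones(()))
    (final_state, ret, scale), _ = jax.lax.scan(step, init, actions)
    return ret + scale * (final_state @ params["val"])


params = init_model(jax.random.PRNGKey(0))
root = jnp.zeros((STATE_DIM,))
plans = jax.random.randint(jax.random.PRNGKey(1), (8, 5), 0, NUM_ACTIONS)  # 8 plans of 5 actions
scores = jax.vmap(evaluate_plan, in_axes=(None, None, 0))(params, root, plans)
best_plan = plans[jnp.argmax(scores)]
```

The unroll inside each plan is sequential (lax.scan), while the plans themselves are independent and batch cleanly with vmap; the sequential part is where search becomes harder to parallelize.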
Next slide. So I'm going to give a very quick introduction to neural-network-guided MCTS as it is done in MuZero, to give some context around the issues. Each MCTS consists of several simulations that happen one after the other, and each simulation consists of three steps. The first is that you traverse the tree: you already have a given tree, and you traverse it from the root to a leaf node using a chosen heuristic. In AlphaZero and MuZero the heuristic is called PUCT, and it basically picks the node with the highest score, where the score combines a policy prior, which is kind of a gut feeling about what the next best action is; action values, which are derived from the tree computation; and an exploration bonus, which is derived from visit counts, that is, how often you've been in a particular node of the tree. So you keep going down.

Next. You keep going down the tree until eventually you reach a leaf node, and then you add a node to the tree.

Next. This is called the expansion, and this is where you actually call your network, because you need to compute the state transition from the leaf node to the new leaf node, the value for that state, and the policy prior for that state, and all of these consist of calling different neural networks. Eventually you cache that computation, put it in the tree, and follow with the backward step, which propagates all that information from the new leaf node to all its ancestors in the tree. So this is basically how neural-network MCTS works, and next I'll let Fabio explain the issues around implementing MCTS in JAX.

Yeah, so hi, I'm Fabio, I'm a research engineer at DeepMind. So why is implementing MCTS efficiently a challenging task? Well, in the use case we have in mind, it is hard because some of us researchers don't really want to use C++ day to day and would rather use higher-level
languages such as Python. But sticking with plain Python, performing MCTS in batch can be quite slow, which makes our research pace slower. Furthermore, as you mentioned, vanilla MCTS is essentially a sequential algorithm, which puts even further constraints on how much we can parallelize computation. There is of course work in this space, but let's stick with the simplest possible scenario. One possible approach to tackle all of the above is to rely on just-in-time compilation to bridge the gap between interpreted and compiled languages, and this is very well aligned with the programming paradigm in JAX.

Next slide. So, off the bat, if you decide to stick with JAX like we did, there are some expected advantages and disadvantages we can foresee. In particular, what we expect is that once we manage to jit the computation, it's going to be quite performant, at least compared to plain Python. Something that is really relevant for us, especially in RL, is saving the cost of moving data in and out of accelerators, which would happen if you broke out to a C++ search engine on the CPU. In turn, if you can jit search, you can jit the whole acting and the whole learning of our agents, which is really relevant for performance. Furthermore, if you can stick to something that looks a lot like NumPy, it's going to be easier to reason about, write, and modify search components. You can also build on top of the JAX primitives that Mihaela and others discussed today: write your code for a single example and then use, for example, vmap for vectorization. Furthermore, there is this huge potential of being differentiable all the way through.

Next one. On the flip side, this is very likely going to be less efficient if you are not batching, for example if you want to deploy a trained RL agent on a
single environment. Also, if you are giving up some of the compute and memory of your accelerators for search, you're going to have a bit less for plain inference, and this has a further impact on how deep you can go with your search, because your search depth will be limited by the accelerator memory. Furthermore, if you're running all the searches in parallel, your performance will be constrained by the slowest instance of your search.

I want to conclude by just showing a couple of code snippets. This is very high level, but I think it nicely reflects what we discussed so far. This is an implementation of the search method of our MCTS JAX class. As you can see, it is very easy to isolate the three main components of the algorithm, which are the simulate, expand, and backward functions, and you can easily play with them as long as they are nice, jittable functions. I also want to highlight that the control flow must be expressed using the jax.lax library: for example, this won't be a plain Python for loop but a fori_loop. Furthermore, if you want to jit, you need to make sure you have fixed shapes in place, so you need to pre-allocate the data structures that will contain all your search statistics.

The next one is an example of the expansion function, where everything that has to do with neural networks can be nicely wrapped into a single call of the recurrent function, as it is in MuZero. This again makes it very easy to break things down and focus only on the few bits you want to do research on. And now we can move on to questions and discussion.

Can we just ask questions? Okay. So the most obvious question, I guess, is how does JAX compare to pre-existing, established frameworks like PyTorch or TensorFlow? What would be the optimal scenario for JAX versus PyTorch, for example? What benefits would I have if I
chose JAX over PyTorch?

That's an excellent question. I think part of the answer would certainly be flavored by personal perspective, but we also have members of the JAX core team here today, and I think they should have first dibs on answering. Matt, JJ, Skye, do you have an answer?

I can say some things at a high level, I guess. We think PyTorch is really great, and TensorFlow is really great too; a lot of people are really happy with those things, and we don't want to frame things in a zero-sum way. There are things that PyTorch does better than JAX does, and there are things that JAX does better. I'm more familiar with the JAX side, so I can say some things about what we are trying to do with it and ways in which those things might help you. As folks talked about in these talks, JAX makes a bet on functional programming. That means that, to the user, things are maybe closer to math, and that's nice, but it also means that JAX can provide some capabilities that work really well, like automatic batching with vmap, and fancy autodiff. We started building JAX because we worked on automatic differentiation, so I think JAX still has a lot of cool features, like forward mode and reverse mode that are composable, and the way they interact with vmap to compute Jacobians and Hessians really fast. It has some experimental features for exponentially faster very-high-order autodiff, so if you want to take, say, 100th-order Taylor expansions, it's quite good at that. These are things that are afforded in part by betting on functional programming, and so you get these transformations. So I'd say JAX is about
the transformations, maybe first and foremost. Another thing to say is that it's very compiler-oriented. I think one of the speakers mentioned that she likes being able to dig into the guts of an optimizer: there's not going to be an optimizer kernel, like an Adam update op, written in C++; it's all in user-level Python. That's because JAX is designed around being compiler-oriented, and that also means you can do some things well, like stage out entire parts of your computation. This is the extreme end, not necessarily what you see yourself every day, but JAX had some MLPerf entries that set world records for training some neural networks extremely quickly, like English-German neural machine translation training in about 16 seconds. That comes out of being compiler-oriented: JAX is based around giving you the control to stage out parts of your program from Python and hand them off to XLA, a super-optimizing compiler that can target not just your single GPU card or your TPUs but also entire TPU supercomputers. Anyway, at a high level I'd say the bet on compilers, the functional programming, the focus on providing many transformations, and even allowing expert users to extend the system of transformations are the things that set JAX apart from how other systems have worked before. Whether that is actually useful to you depends: maybe what's most useful to you is having a ton of example code out there, or a really big community, or there are some workloads that might perform better in other frameworks, for example. But
hopefully that gives you some flavor, along with all the other things folks have talked about here, of where JAX could be useful.

Yeah, maybe I'll just add a couple of things very briefly before we move to the next question. One aspect is that JAX is comparatively very thin: it is very focused on doing a certain range of things and doing them very well. I think that can be quite appealing if you really want to dig in and have full control over what is going on. The fact that you can really delve into JAX and get a good understanding of how everything works, because it's fairly compact and only does a focused set of things well, is something I personally found very appealing, and it might resonate with other people. Second, Matt already touched on this, but the functional programming style, which exposes all the primitives as composable function transformations, has been incredibly powerful, at least for me. The example for me, as I worked with him on this project, has been one of the most eye-opening experiences, because we literally just wrote a single agent interacting with a single environment, dropped in a couple of vmaps and a couple of pmaps, and suddenly we were running a massive experiment across multiple TPUs, with multiple environments and multiple agents. So I just want to shout out again the power of composable function transformations, because it's really amazing.

Can you talk about the debugging experience? What do you get when you inspect things, and also when jit is involved?

I can probably take it. Yeah. So the nice thing about debugging in JAX is that until you turn on jit, you are literally just executing NumPy-like code: it looks a lot like NumPy, and it actually executes a lot like it, so you
can literally pdb into your program, step through it line by line, and check what the arrays contain. Then you can turn on jit to make it faster, but you can still debug everything through this friendlier Python lens. So I think the debugging experience from that perspective is quite nice.

If I could tack onto that: we're always looking for ways to improve there, for example with better error messages and tracking. But on the subject of being able to pop into a Python debugger: as long as you're not using jit, that works even with things like autodiff, or even if you're using vmap. Something that uses both of those is the jacfwd function for forward-mode Jacobians. If you have a function you're taking jacfwd of, and you put a debug statement in it, so you jump into a debugger and start printing values, it'll actually show you both the primal point at which you're linearizing and the entire basis of tangent vectors, all the things you're pushing forward together, because both vmap and autodiff work in pure Python. I don't think it's necessarily the first thing you think to do, but you do have all your values there; you can poke at them and look at them. So hopefully that helps a little bit on the debugging front.

Awesome. So, the latest question, and I don't see any hands raised, but if folks want to raise hands instead of asking questions in the chat, we can also do that. The next question is from Sure Hill: is there a tf.data for JAX planned by either the JAX team or the DeepMind team? Oh, and sorry, apologies for that, Byron. Byron, it seems like you have your hand raised and are waiting to speak, is that correct?

Yeah, but you can address the data one first.

Sorry, the question is: are there
plans for a tf.data for JAX?

Yes. Okay, well, one possible answer to that question is that many people actually do use tf.data with JAX, because you can just go through NumPy arrays as the common interchange format. There's no fancy integration needed; you can really just use these two libraries like any other libraries. And of course this applies to other data loaders as well, or your own custom data loaders. So we don't have any immediate plans on the JAX team to create our own data loading library; we think many other groups and teams have done a great job of this, and we don't know if we can do better. But I'm interested if there are any particular use cases or features that aren't well served by the existing frameworks for data loading.

Awesome, thank you, Skye. And Byron, go for it.

Cool. So first of all, excuse me, I want to thank both the JAX team and the DeepMind team, because I've been using Haiku to reimplement some things in RLax over the last several months, and I've really enjoyed it. Like Matteo said, I find the functional format of it really exciting. One thing that I don't think anyone mentioned, but that I found really nice, is that serialization of model params is super easy: if you have your params out, you just save off your params, you load them back in, and you're done. There's nothing to worry about with fancy loaders or anything like that, so I really like that capability for checkpoints as well. But the one question I have, again having worked on this: can you speak to what the model is for contributing back to these projects? For instance, I noticed you don't have bidirectional RNN support, so I had to roll my own. What is the process for contributing that to Haiku?

So maybe I can take the questions for the JAX ecosystem libraries,
and maybe someone from the JAX team can say more about contributing to JAX core itself. We are definitely open to contributions; we have already taken quite a few, in RLax and in Optax for instance. I think the main message there is: open an issue on GitHub, and let's have a conversation there. It's hard to say without knowing the details, but we have an explicit aim to keep each library quite focused. Similarly to how JAX core does one thing and does it well, in our ecosystem libraries we try to keep Optax focused very clearly on optimization, Haiku very clearly on neural networks, and RLax very clearly only on reinforcement learning. So there might be cases where there is valuable code that doesn't fit exactly into one of the ecosystem libraries, in which case we might have to make a call on whether or not it fits the scope of each library. But I think opening an issue on GitHub would definitely be a nice way of starting a conversation.

Excellent. Now, is there anything that you wanted to add, Skye or Tom?

We love open-source contributions, so yeah, come over to our GitHub and try to make them.

Yeah, plus one to everything Matteo said. Someone actually just added Windows build support, I think, and that was incredible: that was an issue number that we had in the single digits, and some open-source contributor just contributed it, so that was fantastic. So come to our issue tracker. We're often overloaded and time-limited, so if we don't respond, it's not a lack of interest; it's just because we're scrambling with other things. But we love seeing people start discussions and make contributions.

Cool, excellent. Next up, it looks like, is Horse.
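Two remarks from this part of the Q&A, Skye's point that NumPy arrays are the common interchange format between data loaders and JAX, and Byron's point that saving and restoring params is easy, can be illustrated together. This is a minimal sketch under stated assumptions: `numpy_batches` is a hypothetical stand-in for any loader that yields NumPy arrays (with tf.data this would typically be `dataset.as_numpy_iterator()`), the model is a toy linear regression with a hand-written SGD step rather than Haiku or Optax, and `.npz` via `np.savez` is just one simple checkpoint format among many. `sgd_step`, `save_params` and `load_params` are illustrative helpers, not library APIs.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def numpy_batches(xs, ys, batch_size, seed):
    """Hypothetical loader: yields shuffled (x, y) batches as plain NumPy arrays."""
    order = np.random.default_rng(seed).permutation(len(xs))
    for start in range(0, len(xs) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        yield xs[idx], ys[idx]


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # NumPy batches are accepted directly; JAX moves them to the device.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


def save_params(path, params):
    # Params are a pytree of arrays: store its leaves in a NumPy .npz archive.
    leaves = jax.tree_util.tree_leaves(params)
    np.savez(path, *[np.asarray(leaf) for leaf in leaves])


def load_params(path, like):
    # Rebuild the pytree using the structure of a template with the same layout.
    treedef = jax.tree_util.tree_structure(like)
    with np.load(path) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Toy, noise-free regression data generated with NumPy.
rng = np.random.default_rng(42)
xs = rng.normal(size=(256, 2)).astype(np.float32)
ys = (xs @ np.array([2.0, -1.0], dtype=np.float32) + 0.5).astype(np.float32)

params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
for epoch in range(20):
    for x, y in numpy_batches(xs, ys, batch_size=32, seed=epoch):
        params = sgd_step(params, x, y)

path = os.path.join(tempfile.mkdtemp(), "params.npz")
save_params(path, params)
restored = load_params(path, like=params)
```

Because `params` is just a nested dict of arrays, the checkpoint does not depend on any model class; the same two helpers should work for any params pytree with a known structure, such as a nested mapping of arrays.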
All right, yeah. So I guess I wanted to ask more about the debugging topic. I get how, in general, with eager-mode JAX it's easy to debug, because you can just use all your regular tools, but in the presence of transformations that aren't jit, like grad, does it still remain easy to debug? What happens if you have a print statement and then you call grad or vmap on it?

That's a great question; that's sort of what I was getting at, but maybe it was hard to explain in words without showing the code. JAX works by a tracing mechanism, which means it's propagating extra values inside your Python code. So if you have a function, print out an intermediate variable in it, and call grad of that function, you'll see something print out, and JAX will show you the information it has going on under the hood. So you can use pdb to debug things; you just end up looking a little bit behind the scenes at what JAX is doing. In the grad case, for example, when you print a value you'll see a string that shows the primal value at which you're linearizing your function and then, depending on whether you're using forward mode or reverse mode, possibly something about the tangent value. That's getting into the weeds, but a lot of JAX stays in Python (jit stages things out), and for the things that stay in Python we try to make the debugging experience as reasonable as possible. vmap is another example: if you print out the value of an intermediate variable in a vmapped function, it'll actually show you the value it's keeping behind the scenes that has that
full batch dimension. Even though your code thinks it's operating on single examples, if you print a value, JAX will say, actually, I'm hiding a whole batch of examples behind the scenes here. But, all that said, those are just ways in which we've tried to make the Python debugging experience work; I'm sure there are lots of ways it could still be improved.

Excellent, thank you, Matt, and thanks everybody for the questions; these are great. I think the next one is from Robin.

Thanks. Does it make sense to mix in a bit of JAX if you have an existing PyTorch application, or does it really make more sense to always do one or the other completely?

I think it depends on a lot of situation-specific details and what kind of mixture you're talking about, but I actually think there's a lot of potential for interesting ways to compose things. One small example is being able to hand off GPU-backed arrays efficiently between the two, without copies and without much overhead. That might be a nice feature if you want to write part of your program in PyTorch: part of a function is backed by PyTorch operations and part of it is backed by JAX operations. I think both PyTorch and JAX implement the DLPack standard for exchanging memory like this. There are still some caveats, for example a GPU synchronization might be incurred, but other than engineering details that folks are still working on, I think there are ways we can make these things super compatible, and that way users just win. There's no reason to lock people in if you have reason to build stuff together. I can't think of a specific situation off the top of my head, but we want to make it so that you can compose these things
and you're not limited in that way.

Maybe I would just add that these kinds of potential integrations might also be helped by the fact that JAX is quite focused and quite lean. It's not a huge, vertically integrated framework where, if you buy into one aspect of JAX, you need to buy into everything else, a specific neural network library, and so on. So being quite lean might actually make these kinds of integrations easier, although of course it's hard to discuss without specific details.

Skye, actually, in your demo yesterday, was part of the Hugging Face demo loading values from PyTorch checkpoints or something?

Yeah, the BERT fine-tuning demo involved downloading the Hugging Face BERT checkpoint in PyTorch and then loading it into JAX, but I think that probably went through NumPy arrays, because it's downloaded on the host, rather than through DLPack.

Perfect. So, Horse, did you have another question, or is your hand still raised?

No more questions for me.

No worries. Are there any questions in the chat that we didn't get a chance to answer? It looks like Julian answered the chat question about performance, and Avital, David and James shared great links. Robert has his hand raised now. Excellent. Robert?

Hi. I had one question regarding the future of JAX and what we can expect from both the DeepMind internal packages and the core team.

Awesome. I guess I can speak to the DeepMind part. In terms of our libraries, I think this model of having low-level libraries of components with incremental buy-in has been a real win for us. One of the natures of research is that it's broad and unpredictable and moves very quickly. Often, when we try to go the other way and plan in advance what we think we need to build, in terms of a
press-play-and-run framework, by the time we've gotten there the field has moved on. With component libraries, an individual component someone contributes may become less relevant than it was in the past, but that doesn't affect your ability to build others, contribute, expand, or whatever it may be. So this idea of identifying the core features that our research requires, contributing them back as a substrate for sharing research throughout the organization and with the broader community, adding more libraries, and building libraries on top of libraries is working really well for us, and it's something that, at least in the short to medium term, we intend to continue. Maybe someone from the JAX team can answer the question in regard to JAX core.

Yeah, I can say something, but someone else from JAX core, interrupt me if you have a really good answer. I never quite know how to answer a question like that, other than to say that the folks who are working on JAX, myself included, love this project. We're pouring so much of ourselves into it, and to me that's as much of a guarantee that it's not going away as you can get, just because we're personally very invested. The JAX team has grown, and it's working well in places like DeepMind, Google Brain, and research more broadly. I think everyone feels the best days of JAX are ahead; I guess what I'm saying is that JAX is growing, and the folks working on it are very passionate, as you've seen. So I wouldn't worry at all about JAX going away or something like that.

Awesome, and we are at time, but we might have time for one more question, and it looks like Robin had his hand raised.

Oh yeah,
this looks amazing. I was wondering, what do you think are the stumbling blocks that people new to JAX often hit? Are there certain things that are often confusing for people or cause issues when they're just getting started?

We have a great link for that: there's something called the "JAX - The Sharp Bits" Colab in the docs, and I'll link that in the chat.

In addition, I guess I'll add one from the DeepMind side. One of the strengths of JAX is obviously this bet on functional programming, lifting everything into a functional paradigm where things are stateless and you don't need to worry about side effects. We decided to diverge from that slightly with some of our libraries, such as Haiku, and the reason was that we had other considerations to balance: we had so many users already using Sonnet as their way of defining and reasoning about neural networks in TensorFlow, and we really wanted to be able to maintain that API. So there's certainly some cognitive overhead in writing something that looks object-oriented and is converted into functions for you, and sometimes it takes users some time to ramp up, but that was the trade-off that we decided worked for us as an organization. There are certainly other JAX neural network libraries that take different approaches that other people find work better for them, and, in the spirit of incremental buy-in and cross-compatibility, you're free to use any part of any of those with any of our libraries as well. I think for the most part that's been working for people.

Thank you. I would add maybe a common one that was also mentioned in the chat, which is pseudo-random number generation. That is something that, as a newcomer to JAX, might be a little
bit surprising if you're used to NumPy, where you almost don't have to reason about the state of the random number generator at all. I would say, though, that while it can be surprising at first, it actually pays off in the long run. Once you get used to this way of dealing with random number generation, at least personally, I ended up concluding that this was actually the right way all along. There is also a very nice document in the JAX documentation about the PRNG design, explaining the reasons and motivations for the way JAX deals with random numbers. I think that is a great read: as a newcomer, if you're a bit puzzled at the beginning by random numbers, just read that doc, because I think it will clarify a lot.

Yeah, I wholeheartedly agree with what Matteo is saying: once you've used JAX's RNGs, all the other RNGs that silently maintain their own state become very scary all of a sudden.

Cool. So we are at time. Thank you all so much for your questions. There might be one more slide left; is this correct, Matteo? Ha ha, there we go. Cool. So thank you everybody for coming. Please make sure to share your JAX projects on social media, whatever your favorite flavor of social media is. GitHub is also social, and it is my favorite social hangout place. Anything that you create, please tag it with "JAX ecosystem" so that we can make sure to see it and share it. And if you're interested in contributing, there are many issues on the JAX repo that are good for first contributors, and many places to get involved. We look forward to seeing what you create. Thank you so much, everyone, for attending, and thank you to JAX core and all of the JAX ecosystem teams for presenting your work and for getting us all very excited. I know I am somehow significantly more excited than I was even an hour
ago. So thank you everybody, have a great day, have a great NeurIPS, and see you soon.

[…] it off to Matteo. Matteo, do you want to take it over?

Yeah, thank you very much, Paige. So today I'll be presenting with my friends and colleagues Dave, Mihaela, Jun, Fabio and Théo. Our objective is to give you an overview of the ecosystem of tools and libraries that we've been building around JAX, to give you a feel for why we as a group and as a company are quite excited about JAX, and hopefully to answer some of your questions: what your doubts and concerns are, and whether JAX could also be useful for your own research and projects. So let me start with a very brief introduction to JAX and, from a bird's-eye view, why it might be an exciting framework in which to build machine learning projects and research. Many of you will already be familiar with it, but let me just repeat the fundamentals. At its core, JAX is a Python library designed for high-performance numerical computing, and among its key ingredients it supports three main areas. One is differentiation: forward- and reverse-mode automatic differentiation of arbitrary numerical functions. Here we have a whole lot of JAX primitives, like grad, hessian, jacfwd and jacrev, that allow you to exploit the automatic differentiation capabilities of JAX. The second broad family of primitives that JAX offers its users relates to vectorization, and automatic vectorization in particular. In a lot of machine learning we rely on a single-instruction, multiple-data style of programming, where we typically want to apply the same transformations to a whole batch of data with the same structure; for instance, we might be computing a loss over a big batch of input and output samples. JAX makes life a lot easier for
researchers and practitioners of machine learning by exposing some very neat and simple abstractions that simplify single-instruction, multiple-data programming. These are mostly vmap and pmap, and we're going to hear some examples today of how these can really make a difference in writing clean but expressive code. Finally, the third big area inside JAX is jit compilation, that is, just-in-time compilation. JAX is built on top of XLA and uses XLA to just-in-time compile your code, making your CPU code faster if you're running on CPU, but also giving you transparent GPU or Cloud TPU acceleration. The important feature of JAX, if you take away just one message about what JAX is, is that all these functions and abstractions are implemented as composable program transformations. So what does this actually mean? Let's take a very concrete, simple example and consider a numerical function that just squares its first input and adds it to the second argument. You could write it in Python this way, and x and y could be, for instance, NumPy arrays. The value of this function is of course evaluated like any Python function: you just call the function and pass some inputs, and the first input will be squared and added to the second one. What would computing a gradient look like with JAX? The neat thing about JAX is that the gradient of this function is also a function: you just call jax.grad on a function and you get back something that is still a Python function, but one that, if you call it, computes the derivative of the original function, in a very neat and transparent way. Importantly, this transformation is composable, so you could, for instance, compute the second-order gradient by calling grad of grad of the original function, and this again gives you just a Python function which, if you pass x and y, gives you the second-order gradient at those
Similarly, you can mix and match this with other program transformations, like compilation. If you call jit of grad of grad of the original function, this gives you a compiled second-order gradient function that, again, you can call like any other Python function. The first time you call it, it will trace the code and compile it using XLA; the second and all following times you call the same function, it will execute the precompiled XLA code and therefore be much faster. And of course this doesn't stop here: all the JAX primitives are neatly composable. For batching, for instance, you just write your function as if you were dealing with single examples, but if you vmap that function, you get a function that now expects a batch of inputs and computes the same function on the entire batch. And finally, if you then want to execute this batched, compiled second-order gradient calculation in parallel on multiple accelerators, for instance multiple GPUs, you just need to chain one more transformation: now you pmap a vmapped, jitted grad of a grad of a function, and again this is just a Python function that you can call on your inputs. I hope this gives you a flavor of how composable program transformations enable JAX to expose a very thin set of primitives that are easy to understand, but that you can then combine into very rich programs supporting quite powerful use cases.

Before I leave it to my colleagues to delve into more details, I just want to conclude with a couple of final remarks on why JAX is so convenient. I think there are a few things to remember. One is, of course, the composable nature of all its abstractions. But it's also good to remember that all of the numerical functions are written with a syntax that mirrors NumPy, so you're literally just writing NumPy code, but then you can transform it using these program transformations.
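The composition described here can be written out directly. This is a sketch assuming the same toy f; the batch size of 4 is arbitrary, and the pmap axis is sized with jax.local_device_count() so the snippet also runs on a single-CPU machine.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x ** 2 + y

# jit of grad of grad: traced and compiled with XLA on the first call, cached afterwards.
d2f = jax.jit(jax.grad(jax.grad(f)))

# vmap: the per-example function now accepts a batch along the leading axis.
batched_d2f = jax.vmap(d2f)

# pmap: one more leading axis, mapped across devices. Its size must not exceed
# the number of available devices (on a laptop CPU this is usually 1).
n_dev = jax.local_device_count()
parallel_d2f = jax.pmap(batched_d2f)

xs = jnp.arange(4.0 * n_dev).reshape(n_dev, 4)  # shape (devices, batch)
ys = jnp.ones_like(xs)
out = parallel_d2f(xs, ys)                       # every entry is 2.0
```

Each wrapper returns an ordinary Python callable, which is why the transformations can be stacked in any order that makes sense for the shapes involved.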
This means that JAX can feel quite familiar even when you're just starting, because Python and NumPy are so widely used. And a couple of other points I want to highlight: first, JAX is not a vertically integrated framework like many other ML frameworks. It's really focused on getting this core bit of numerical computing right, and it still provides you with a very rich ecosystem and community built around it that can make other, more classical things, like building neural networks, easy for you. And this is what David is going to be talking about now.

Hi, yeah, thanks a lot, Matteo, that's a great overview. So my name is David, I'm a research engineer at DeepMind, and I'm here today to talk a little bit about our DeepMind JAX ecosystem, which is essentially a collection of libraries that we've been developing over the past 18 months or so. They build on top of JAX, as Matteo has just been explaining, and add functionality for specific machine learning research. Next slide, please.

So why an ecosystem? DeepMind researchers have had a lot of great initial success using JAX. This being a NeurIPS-hosted event, it's worth mentioning that many of the papers we're presenting at NeurIPS this year use JAX under the hood to produce their results. As engineers and researchers at DeepMind, we're constantly asking how we can continue to support and accelerate this great work and the early success we've seen. There are a few considerations that go into this. The first is what Matteo was saying: JAX is not a vertically integrated machine learning framework. This is its strength: it does a core set of things very, very well, and what we want to do is build on top of this and continue to build our own libraries to meet the specific
machine learning research and our particular needs. Whatever we build needs to support our rapidly evolving requirements, for instance a specific focus on reinforcement learning, among other things. And where possible, we want to strive for consistency and compatibility with frameworks and tools we've released in the past, for instance our TensorFlow libraries, things like Sonnet and TRFL. The solution we've arrived at is an ecosystem, which is basically a collection of libraries of reusable and unopinionated JAX components. Each library is intended to do one specific thing well and, importantly, supports incremental buy-in. For instance, we're going to show you libraries for neural networks, for optimization, and for reinforcement learning. From day one, we've made sure to develop these in a way that lets you pick and choose the pieces you want without being locked into any decisions about other libraries or tools within the same ecosystem. And of course, where possible, we open-source everything to enable research sharing and reproducibility of the results we build on top of the JAX ecosystem. Thanks a lot.

I'm quickly going to go over three examples of libraries we've built, although there are others and I encourage you all to check them out. For more details, we recently posted a blog on the deepmind.com website that goes through this in more depth.

So, Haiku is our JAX library for neural networks. As Matteo was highlighting, much of the strength of JAX comes from this functional programming paradigm, whereby we have stateless functions and we can compose function transformations: grad or vmap of whatever returns a function, and that function can be statelessly applied to some NumPy or JAX NumPy array. Often, though, some abstractions that are common in machine learning, things like neural network trainable parameters and state, fit the object-oriented paradigm potentially better, or at least that's how people are used to thinking about them.
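To make the contrast concrete, here is what the purely functional style looks like in plain JAX, without Haiku: parameters are an explicit pytree (here a dict of arrays) created by an init function and passed into a pure apply function, so jax.grad returns gradients with the same structure. The helper names init_linear and apply_linear are hypothetical. Haiku's hk.transform produces an analogous init/apply pair from object-oriented module code; this sketch neither uses nor reproduces Haiku's API.

```python
import jax
import jax.numpy as jnp

def init_linear(key, in_dim, out_dim):
    # Hypothetical helper: the "module state" is just a dict of arrays (a pytree).
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim),
        "b": jnp.zeros((out_dim,)),
    }

def apply_linear(params, x):
    # Pure function: the output depends only on the arguments, with no hidden state.
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    return jnp.mean((apply_linear(params, x) - y) ** 2)

params = init_linear(jax.random.PRNGKey(0), in_dim=3, out_dim=2)
x = jnp.ones((5, 3))
y = jnp.zeros((5, 2))
grads = jax.grad(loss)(params, x, y)  # a dict with the same keys and shapes as params
```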
So Haiku tries to bridge the gap between the object-oriented world and the functional world of JAX. It essentially provides tools with which you can take stateful objects, such as neural networks, and convert them into initialization and application functions that can then be used directly in JAX. Our researchers have had great success trivially porting previous TensorFlow results, such as the ones shown on this slide, into JAX, and it's a widely used and mature framework, both within DeepMind and more widely within Google and the public community. Next slide.

Optax. So Optax is for optimization. Much like with neural networks, it often makes sense to think about optimizers in an object-oriented way, as having state. Optax is essentially a gradient processing and optimization library. At its core, it provides a zoo of common optimizers that people know and love, things like SGD, Adam, momentum, whatever it may be, and, importantly, these are defined as a chain of simple Optax primitives: fundamental gradient processing primitives that are common to many optimizers and many other things as well. So Optax provides simple out-of-the-box optimizers built by chaining these components, and gives users the functionality to build their own. In addition, it provides a lot of useful utilities that make it trivial to apply gradient-based updates to neural network parameters, and not just parameters from Haiku, our JAX neural network library. True to the spirit of incremental buy-in and cross-compatibility, all, or at least most, of the popular JAX neural network libraries you may be familiar with should be compatible with Optax: essentially anything that represents the parameters as a tree of NumPy arrays. Again, it's widely used within DeepMind, with a growing user base outside.
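The idea of building optimizers by chaining simple gradient transformations can be sketched in a few lines of plain JAX. This is not Optax's implementation: the Transform tuple and the trace, scale, and chain functions below are hypothetical stand-ins that only mirror the concept (Optax itself exposes combinators such as optax.chain). Parameters can be any pytree of arrays, which is the compatibility point made in the talk.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

class Transform(NamedTuple):
    # Hypothetical stand-in for a gradient transformation: a pair of pure functions.
    init: Callable    # params -> state
    update: Callable  # (updates, state) -> (new_updates, new_state)

def trace(decay):
    # Momentum: keep a decaying sum of past gradients and emit it as the update.
    def init(params):
        return jax.tree_util.tree_map(jnp.zeros_like, params)
    def update(updates, state):
        new_state = jax.tree_util.tree_map(lambda g, m: g + decay * m, updates, state)
        return new_state, new_state
    return Transform(init, update)

def scale(factor):
    # Stateless rescaling of the incoming updates.
    def init(params):
        return ()
    def update(updates, state):
        return jax.tree_util.tree_map(lambda g: factor * g, updates), state
    return Transform(init, update)

def chain(*transforms):
    # Apply transformations in sequence, threading each one's own state.
    def init(params):
        return tuple(t.init(params) for t in transforms)
    def update(updates, state):
        new_state = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)
            new_state.append(s)
        return updates, tuple(new_state)
    return Transform(init, update)

# SGD with momentum, assembled from two primitives.
optimizer = chain(trace(decay=0.9), scale(-0.1))

def loss(params):
    return jnp.sum(params["w"] ** 2)

params = {"w": jnp.array([1.0, -2.0])}
opt_state = optimizer.init(params)
for _ in range(3):
    grads = jax.grad(loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```

Because every piece is a pair of pure functions over pytrees, swapping momentum for another accumulator, or adding clipping in front, is a matter of changing the arguments to chain.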
Next slide, please. So, as you probably know, DeepMind cares quite a lot about reinforcement learning. One of the complexities of reinforcement learning is that it's very easy to get things wrong: for instance, if you place a stop-gradient in the wrong place, which is very easy to do, the entire experiment is going to fall apart. The idea of RLax is essentially to provide a library of trusted, battle-tested reinforcement learning components, not even full algorithms but sub-components of algorithms. For instance, on the right-hand side you can see how a Q-learning primitive and a squared-loss primitive can be combined to define a simple loss function that you might want to use. Essentially, this provides a substrate with which our researchers can use components and algorithms they trust, while also quickly contributing their own ideas back to be shared with the broader community. A lot of the emphasis in this library has been on readability. What does this mean? We use Unicode everywhere, so if you have a look at our docstrings, we're not scared to make them look like textbook examples, and we're also not afraid to repeat ourselves: if people land on a function for the TD error, we want them to see where the TD error is computed, not need to chase it down through three or four different functions.
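Here is a sketch of the pattern described: a Q-learning TD-error primitive and a squared-loss primitive combined into a loss. It is written in plain JAX so it runs without RLax; the function names are illustrative, and RLax ships its own, more general versions (for example rlax.q_learning and rlax.l2_loss). Note where stop_gradient sits: on the bootstrap target, which is exactly the kind of placement the talk warns is easy to get wrong.

```python
import jax
import jax.numpy as jnp

def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    # One-step Q-learning TD error for a single transition. The bootstrap target
    # is wrapped in stop_gradient, so gradients flow only through q_tm1[a_tm1].
    target = r_t + discount_t * jnp.max(q_t)
    return jax.lax.stop_gradient(target) - q_tm1[a_tm1]

def l2_loss(x):
    return 0.5 * jnp.square(x)

def loss_fn(q_tm1, a_tm1, r_t, discount_t, q_t):
    # Combine the two primitives into a simple loss.
    return l2_loss(q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t))

q_tm1 = jnp.array([1.0, 2.0])  # Q-values at the previous state
q_t = jnp.array([0.5, 3.0])    # Q-values at the next state
grad_q_tm1 = jax.grad(loss_fn)(q_tm1, 1, 1.0, 0.9, q_t)           # approximately [0., -1.7]
grad_q_t = jax.grad(loss_fn, argnums=4)(q_tm1, 1, 1.0, 0.9, q_t)  # zeros: the target is stopped

# The same per-transition loss, batched with vmap over a leading axis.
batched_loss = jax.vmap(loss_fn)
```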
Next slide. So I don't have time today to go over all the libraries that we've released as part of our ecosystem, but there are others, for instance the Jraph library for graph neural networks, which we released last week, as well as our Chex library for testing and reliability. Here are some links; I encourage you to check them out. In addition to this worldview of an ecosystem that we're telling you about today, I suppose there's an additional shell around the outside: frameworks that are built on top of these libraries. Some examples from DeepMind are Jaxline and Acme, for supervised learning and RL respectively. I encourage you to check these out. And in addition to the work from DeepMind we've been talking about today, there's a lot of really great work being built on top of the JAX ecosystem, both within Google and in the public more generally. So for the rest of this presentation, I'm going to hand over to people who are going to talk more about research results that have been based on JAX and the JAX ecosystem, starting with Mihaela, who's going to talk about GANs and generative models.

Thanks. So I hope everyone can hear me. I'm super excited to talk today about what I think makes JAX amazing for generative models, and specifically GANs. Next slide, please. The reason I chose GANs is that they're a bit different from the standard view we might have of other generative models, the paradigm of one model, one loss. In the GAN case, we have two models, the players: the generator, which produces data, and the discriminator, which learns to distinguish between this generated data and the real data. The goal of the discriminator is to be really good at distinguishing between these two, and the goal of the generator is to fool the discriminator into thinking that the data it generates is real. This can be set up either as a zero-sum game or via other types of losses; they don't necessarily need to be connected, as long as the underlying principle holds: we have one player that learns to distinguish and one player that learns to fool, hence the name adversarial. Next slide, please.

Now, if we want to implement this in JAX, one thing we notice straight off the bat is that when you implement a GAN, you might want to do multiple discriminator updates for each generator update, or vice versa, though more often it's the discriminator. This is super trivial in JAX: you just write a Python for loop for each of the players, and we all know how to do
that. Inside this Python for loop for each player, you update the parameters of that player. The really neat thing about this, which Matteo also highlighted earlier, is that we can get gradient functions, not gradient values. We can ask JAX: can I please have the gradient function, the gradient of the discriminator loss with respect to the discriminator parameters? Then I can evaluate that function at the current values of both the discriminator parameters and the generator parameters, plus whatever data batch I might have, and so on. Crucially, there's no need for any stop-gradient on the parameters of the generator, even though the value of that loss depends on them, because we only ask JAX for the gradient with respect to the first argument, which is the discriminator parameters. So this is very nice, very easy to reason about, and almost like pseudocode: if you asked me to write pseudocode for a GAN algorithm, it would look very similar to this, and I think that's really nice. Once we have the gradients, we can pass them to an optimizer, like the Optax optimizers that were talked about, get an update, and get our new discriminator parameters. Then we do exactly the same thing for the generator, only this time we use a different loss, the generator loss, and a different gradient function: the gradient of the generator loss with respect to the generator parameters. So this is all very neat. Next slide, please.

One thing I want to highlight here is that we now have direct access to gradients. For example, if I want to see what the norm of the gradient looks like at different layers, I can easily do that: I just add one or two lines of NumPy code that I'm familiar with, and I get those statistics. I think this is very useful in allowing us as researchers to build intuition about what our models are doing in a very simple way. Next slide, please.
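A toy sketch of the training loop described here, assuming a one-dimensional "real" distribution, an affine generator, and a logistic-regression discriminator, with plain SGD standing in for an Optax optimizer. All names and hyperparameters are hypothetical. It illustrates the points from the talk: a Python for loop for the extra discriminator steps, gradient functions taken with respect to the first argument only (so the other player's parameters need no stop_gradient), and gradient statistics computed with ordinary array code.

```python
import jax
import jax.numpy as jnp

def generator(g_params, z):
    # Toy generator: an affine map of Gaussian noise.
    return g_params["scale"] * z + g_params["shift"]

def discriminator(d_params, x):
    # Toy discriminator: logistic regression, returns logits.
    return d_params["w"] * x + d_params["b"]

def disc_loss(d_params, g_params, z, real):
    fake = generator(g_params, z)
    # -log D(real) - log(1 - D(fake)), written with log_sigmoid for stability.
    return (-jnp.mean(jax.nn.log_sigmoid(discriminator(d_params, real)))
            - jnp.mean(jax.nn.log_sigmoid(-discriminator(d_params, fake))))

def gen_loss(g_params, d_params, z):
    # Non-saturating generator loss: -log D(G(z)).
    return -jnp.mean(jax.nn.log_sigmoid(discriminator(d_params, generator(g_params, z))))

# Gradient functions (not values), each w.r.t. the first argument only, so the
# other player's parameters need no stop_gradient.
disc_grad_fn = jax.jit(jax.grad(disc_loss))
gen_grad_fn = jax.jit(jax.grad(gen_loss))

def sgd(params, grads, lr=0.05):
    # Plain SGD standing in for an Optax optimizer.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
d_params = {"w": jnp.array(0.1), "b": jnp.array(0.0)}
g_params = {"scale": jnp.array(1.0), "shift": jnp.array(0.0)}
num_disc_steps = 2  # discriminator updates per generator update

for _ in range(100):
    for _ in range(num_disc_steps):  # just a Python for loop
        key, z_key, x_key = jax.random.split(key, 3)
        z = jax.random.normal(z_key, (64,))
        real = 3.0 + jax.random.normal(x_key, (64,))  # "real" data centered at 3
        d_params = sgd(d_params, disc_grad_fn(d_params, g_params, z, real))
    key, z_key = jax.random.split(key)
    g_grads = gen_grad_fn(g_params, d_params, jax.random.normal(z_key, (64,)))
    g_params = sgd(g_params, g_grads)

# Gradients are ordinary pytrees, so inspecting them is a line of array code.
grad_norms = {name: float(jnp.sqrt(jnp.sum(jnp.square(g)))) for name, g in g_grads.items()}
```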
Something else, maybe specific to GANs, that JAX highlights very nicely is that it gives you more control and makes you think about the kind of decisions you're making. One type of decision, for example: when I update the discriminator, what should I do about the generator state? By state I don't mean the generator parameters, but, say, the generator's batch norm statistics. Do I or do I not want to update those statistics? Let's leave aside for now whether you actually want to do this in GANs; the point is that when you implement GANs in JAX, it forces you to think about this and to make an explicit decision. Briefly looking at this quick implementation of the discriminator loss that I made up, we see that we have a forward pass through the generator, which returns a new generator state. Now, when I return from the discriminator loss, I can say I want to use this new generator state in conjunction with the new discriminator state, or (next slide, please) I can just decide to ignore it and return the generator state that was given to me. Very, very easy. Again, I have to think about what I want, and getting what I want is just a matter of changing a few lines of code here, instead of digging somewhere deep for something I might not have access to. Next slide, please.

And last but not least, I just want to say one more time that this functional approach makes code very close to math, which makes it much easier to reason about, especially in generative modeling, where we often have a lot of distributions and gradients of things that depend on distributions. I'm going to highlight two things that Matteo talked about earlier: being able to vectorize functions and computing Jacobians.
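Here is a sketch of both tools applied to the kind of estimate discussed next: a reparameterized Monte Carlo gradient of E[f(x)] with x ~ N(mu, sigma^2), where f is written for a single sample and vectorized with vmap, followed by a Jacobian of a small vector-valued function. The test function f, the Gaussian, the sample count, and the function g are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A nonlinear test function written for a single scalar sample.
    return jnp.sin(x) + x ** 2

def reparam_objective(params, eps):
    # Reparameterization: x = mu + sigma * eps with eps ~ N(0, 1), so the sampling
    # distribution no longer depends on the parameters, which move inside f's argument.
    mu, log_sigma = params
    x = mu + jnp.exp(log_sigma) * eps
    return jnp.mean(jax.vmap(f)(x))  # vmap vectorizes the per-sample f over all samples

eps = jax.random.normal(jax.random.PRNGKey(0), (10_000,))
params = jnp.array([0.5, jnp.log(0.1)])
grad_estimate = jax.grad(reparam_objective)(params, eps)
# Analytically, d/dmu E[f(x)] = exp(-sigma**2 / 2) * cos(mu) + 2 * mu, about 1.873 here.

# Asking for a Jacobian instead of a gradient is a small change:
def g(params):
    mu, log_sigma = params
    return jnp.stack([mu ** 2, jnp.exp(log_sigma), mu * log_sigma])

jac = jax.jacrev(g)(params)  # shape (3, 2); jax.jacfwd(g)(params) gives the same values
```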
So if I want to compute, let's say, gradients with respect to some parameters of an expectation under a distribution that depends on those parameters, this is not trivial. We can use the score function estimator, but that has higher variance, and for some distributions we can use the reparameterization trick, which allows us to rewrite the expectation with respect to another distribution that doesn't depend on these parameters and push the parameters inside. Even with this new form, we still have an expectation, so we have to compute some function over multiple samples. What we can do is expect the user to always pass us vectorized functions, which can be okay if it's only neural networks, but sometimes it's something else: we don't only have to compute gradients of neural network functions. Or we can do it for them, and JAX really allows us to do that by saying: I'm going to vectorize the function for you, so you can pass in anything you want. I find this very, very useful for tests, because even though we might use neural networks in our experiments, we often want to test with a complicated nonlinear function, and we don't want to spend the time vectorizing all of the functions in our tests. And again, if I want to compute a Jacobian instead of a gradient, it's just a matter of changing a few characters and I get what I want. So I think this is really convenient and very easy to reason about. That's it for me, and I'm going to hand over to Jun, who's going to talk about meta-gradients.

Thanks. Hi everyone, I'm Jun, I'm a research scientist at DeepMind, and I'm going to briefly talk about how I used JAX for my recent work. Next slide, please. In this work on discovering reinforcement learning algorithms, we try to meta-learn a reinforcement learning update rule from a distribution of agents and environments, as shown in this figure. But there were several technical challenges because of the unique problem setup. Thanks. So first, we wanted to simulate many
independent learning agents, where each agent interacts with its own set of environments. This is already quite unusual, because normally in reinforcement learning we consider just one learning agent, but here we wanted to simulate multiple lifetimes of agents simultaneously. Next, please. At the same time, we wanted to apply the same update rule, which is the meta-learner, to all learning agents in a completely synchronized way. Next, please. We also wanted to calculate the meta-gradient over this agent update procedure, which requires calculating second-order gradients. And finally, we wanted to massively scale up this approach by increasing the number of learning agents without introducing much extra computational cost. So these were quite challenging. To address these challenges we used JAX, and JAX actually allowed us to handle all of them easily.

In the next slide, I'm going to briefly describe how we implemented this system using JAX. At the beginning, we implemented every environment in JAX so that we could apply vmap and pmap later. As you can see from this figure, we first implemented a single update rule, a single agent, and a single JAX environment interaction, which was quite simple. If you go to the next slide, we then added a vmap to implement multiple JAX environments for one learning agent. In the next slide, we added another vmap at the outer scope to implement multiple learning agents, as in this figure, where each agent still interacts with its own set of environments. And finally, we added pmap to create multiple copies of the same computation graph and distribute them across multiple TPU cores, as in this figure, with a shared update rule. So in this figure, each TPU core essentially has its own set of agents and its own set of environments, but they all share the same update rule, they all calculate the meta-gradient in a completely synchronized way, and then they perform the meta-update.
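The layering described here (one agent and one environment, then vmap over environments, vmap over agents, and pmap over devices, with a meta-gradient taken through the agent update) can be sketched with a deliberately tiny toy problem. Everything below is hypothetical and not the paper's setup: each environment is a quadratic loss around a random target, the "update rule" is a single learned step size eta, and only eta is meta-trained. Differentiating the post-update loss with respect to eta goes through the inner jax.grad, so it is a second-order gradient.

```python
import jax
import jax.numpy as jnp

def env_loss(theta, target):
    # One agent interacting with one environment: a toy quadratic loss.
    return (theta - target) ** 2

def agent_update(eta, theta, targets):
    # The "update rule", parameterized by the meta-parameter eta (here just a step size).
    # vmap over this agent's own set of environments.
    grads = jax.vmap(jax.grad(env_loss), in_axes=(None, 0))(theta, targets)
    return theta - eta * jnp.mean(grads)

def meta_loss(eta, theta, targets):
    # Loss after one update; its gradient w.r.t. eta differentiates through jax.grad above.
    new_theta = agent_update(eta, theta, targets)
    return jnp.mean(jax.vmap(env_loss, in_axes=(None, 0))(new_theta, targets))

def population_meta_loss(eta, thetas, targets):
    # Outer vmap: many agents, each with its own environments, sharing one eta.
    return jnp.mean(jax.vmap(meta_loss, in_axes=(None, 0, 0))(eta, thetas, targets))

def device_meta_grad(eta, thetas, targets):
    g = jax.grad(population_meta_loss)(eta, thetas, targets)
    # Average across devices so every core applies the same synchronized meta-update.
    return jax.lax.pmean(g, axis_name="devices")

# pmap: one copy per device; eta is broadcast (in_axes=None), the rest is split.
p_meta_grad = jax.pmap(device_meta_grad, axis_name="devices", in_axes=(None, 0, 0))

n_dev = jax.local_device_count()
n_agents, n_envs = 8, 16
thetas = jnp.zeros((n_dev, n_agents))
targets = 1.0 + jax.random.normal(jax.random.PRNGKey(0), (n_dev, n_agents, n_envs))

eta = jnp.array(0.1)
for _ in range(20):
    eta = eta - 0.05 * p_meta_grad(eta, thetas, targets)[0]  # identical on every device
```

In this toy, with every theta starting at zero, eta should move toward 0.5, the step size that sends each agent straight to the mean of its targets; the per-agent and per-environment code never mentions the batch dimensions.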
If you go to the next slide, this is the pseudocode of our algorithm. The top part implements the agent update, the bottom part implements the meta-gradient calculation, and the yellow parts are the vmaps and pmap that I mentioned in the previous slides. If we just remove these vmaps and pmaps, it essentially becomes a single-agent, single-environment interaction, but by just adding a few vmaps and pmaps, as in this figure, we can easily convert the simplest single-interaction implementation into the massively parallel system that I showed in the previous slide. In the actual experiment we used a 16-core TPU, and with it we could simulate 1,000 parallel learning agents and 60,000 parallel environments with one globally shared update rule, processing more than 3 million steps per second. Next slide, please. To summarize: we had several interesting challenges because of this unusual problem setup, but by using JAX and TPUs we were able to handle all of them quite easily without much engineering effort, which was quite nice for us. We also got quite interesting results out of this project, so if you're interested, please check out the paper and come by the poster session tomorrow. Yeah, that's all I have, thanks. I'm going to hand over to Fabio and Theo.

Hi everyone, I'm Theo, I'm a research scientist at DeepMind, and I'm going to present some work we're doing on search and model-based RL with Fabio. Next. Okay, so here we're shifting gears a little bit. So far we have looked at applications of JAX that crucially leverage its gradient computation capabilities, and you may ask: is this the main use of JAX? Is gradient-based computation the main thing we care about? So here we showcase another application where JAX enables fast research iteration; it uses gradients, but they're not the core of the compute. It's basically Monte Carlo tree
search, in a model-based RL setting, as seen in AlphaZero and MuZero. We are training RL agents that plan using an internal model of the world: they form a plan, which is a sequence of actions that they optimize over, and then use that plan both to act in the real world and to update the policy priors that help guide the planning. There are challenges in doing so. The first is that in typical neural-network-guided search, the control logic is tightly integrated with your network machinery, and this is actually quite tricky to debug. There are challenges around scalability and parallelism, because search is inherently a sequential algorithm; Fabio will say more on this later. And as is typical for model-based RL setups, there are a lot of issues around the data pipelines, not so much issues as design choices: use of replay, the share of synthetic versus real data, how to use data for policy learning versus model learning, and so on. So you need a framework that enables you to quickly test loads of different ideas in that kind of setup. Next. Sorry, just one before that, the one with the picture.

So in the last few years we've seen an explosion of work in a space at the intersection of search-based algorithms, which evaluate different candidate solutions using a more discrete type of reasoning, and neural networks. This has been applied to games such as Go or Atari, puzzle games like Sokoban, as well as robotics, chemical design, and so on. Next slide. The particular algorithm that we investigate and replicate is MuZero, an extension of AlphaZero that learns a model of the environment. It's a paper from last year that obtains state-of-the-art results on a variety of games such as chess, shogi, Go, and Atari, and, again, has
a learned neural network model of the environment. Next slide. I'm going to give a very quick introduction to neural-network-guided MCTS as it is done in MuZero, to give some context around the issues. Each MCTS consists of several simulations that happen one after the other, and each simulation consists of three steps. First, you traverse the tree: you already have a given tree, and you traverse it from the root to a leaf node using a chosen heuristic. In AlphaZero and MuZero the heuristic is called PUCT, and it picks the node with the highest score, where the score combines action values, derived from the tree computation, with an exploration bonus that depends on a policy prior (a kind of gut feeling about the next best action) and on visit counts, that is, how often you've been in a particular node of the tree. So you keep going down. Next. You keep going down the tree until eventually you reach a leaf node, and then you add a node to the tree. Next. This is called the expansion, and this is where you actually call your neural networks, because you need to compute the state transition from the leaf node to the new node, the value for that state, and the policy prior for that state, all of which consist of calling different neural networks. You then cache that computation, put it in the tree, and follow with the backward step, which propagates all that information from the new leaf node to all of its ancestors in the tree. So this is basically how neural-network-guided MCTS works, and next I'll let Fabio explain the issues around implementing MCTS in JAX.

Yeah, so hi, I'm Fabio, I'm a research engineer at DeepMind. So why is implementing MCTS efficiently a challenging task? In the use case we have in mind, it's hard because some of us researchers don't really want to use C++
day to day and would rather use higher-level languages such as Python. But sticking with plain Python, performing MCTS in batch can be quite slow, which slows down our research pace. Furthermore, as Theo mentioned, vanilla MCTS is essentially a sequential algorithm, which puts even further constraints on how much we can parallelize the computation. There is of course work in this space, but let's stick with the simplest possible scenario. One possible approach to tackle all of the above is to rely on just-in-time compilation to try to bridge the gap between interpreted and compiled languages, and this is very well aligned with the programming paradigm in JAX. Next slide.

So, off the bat, if you decide to stick with JAX like we did, there are some advantages and disadvantages we can foresee. In particular, we expect that once we manage to JIT the computation, it's going to be quite performant, at least compared to plain Python. Something really relevant for us, especially in RL, is saving the cost of moving data in and out of accelerators, which would happen if you broke out to a search engine in C++ on the CPU. In turn, if you can JIT the search, you can JIT the whole acting and the whole learning of our agents, which is really relevant for performance. Furthermore, if you can stick to something that looks a lot like NumPy, it's going to be easier to reason about, write, and modify search components. You can also build on top of the JAX primitives that Mihaela and others discussed today: write your code for a single example and then use, for example, vmap for vectorization. And there is the huge potential of being differentiable all the way through. Next one. On the flip side, this is very likely going to be less efficient if you are not batching, for
example, if you want to deploy a trained RL agent on a single environment. Also, if you give up some of the compute and memory of your accelerators for search, you'll have a bit less for plain inference, and this has a further impact on how deep you can go with your search, because your search depth will be limited by accelerator memory. Furthermore, if you're running all the searches in parallel, your performance will be constrained by the slowest instance of your search.

I want to conclude by showing a couple of code snippets. They're very high level, but I think they nicely reflect what we've discussed so far. This is an implementation of the search method of our MCTS JAX class. As you can see, it's very easy to isolate the three main components of the algorithm, the simulate, expand, and backward functions, which you can easily play with as long as they are nice, jittable functions. I also want to highlight that the control flow must be expressed using the jax.lax library: for example, this won't be a plain Python for loop but a fori_loop. Furthermore, if you want to JIT, you need to make sure you have fixed shapes in place, so you need to preallocate the data structure that will contain all your search statistics. Next one. This is an example of the expansion function, where everything to do with neural networks can be nicely wrapped into a single call of the recurrent function, as in MuZero. This again makes it very easy to break things down and focus only on the few bits you want to do research on.
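The two constraints Fabio mentions, control flow through jax.lax and fixed, preallocated shapes, can be seen in a much-simplified sketch. This is not the speakers' MCTS class: it searches only over the actions at the root (a one-level tree), uses an AlphaZero-style PUCT score, and replaces the learned model with a hypothetical recurrent_fn that just reads a value from an embedding. It is written for a single root and then vmapped and jitted over a batch of roots.

```python
import functools

import jax
import jax.numpy as jnp

def recurrent_fn(root_embedding, action):
    # Hypothetical stand-in for the learned model: a value for taking `action` at the root.
    return jnp.tanh(root_embedding[action])

def search(prior, root_embedding, num_simulations, c_puct=1.25):
    num_actions = prior.shape[0]
    # Fixed-shape statistics, preallocated so the whole loop can be jitted.
    init_stats = (jnp.zeros(num_actions), jnp.zeros(num_actions))

    def simulate(_, stats):
        visit_counts, value_sums = stats
        # Selection: Q + c * P * sqrt(total visits) / (1 + visits), PUCT-style.
        q = jnp.where(visit_counts > 0, value_sums / jnp.maximum(visit_counts, 1.0), 0.0)
        u = c_puct * prior * jnp.sqrt(visit_counts.sum() + 1.0) / (1.0 + visit_counts)
        action = jnp.argmax(q + u)
        # Expansion/evaluation: all neural-network work sits behind one call.
        value = recurrent_fn(root_embedding, action)
        # Backup: update the selected action's statistics (functional, out-of-place).
        return visit_counts.at[action].add(1.0), value_sums.at[action].add(value)

    # A lax.fori_loop instead of a Python for loop.
    return jax.lax.fori_loop(0, num_simulations, simulate, init_stats)

# Write it for one root, then vmap over a batch of roots and jit the result.
batched_search = jax.jit(jax.vmap(functools.partial(search, num_simulations=32)))

priors = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(0), (4, 3)), axis=-1)
embeddings = jnp.tile(jnp.array([0.9, -0.5, 0.1]), (4, 1))
visit_counts, value_sums = batched_search(priors, embeddings)  # each of shape (4, 3)
```

A real tree would also need preallocated node and edge arrays with a maximum size fixed in advance, which is the memory-versus-depth trade-off mentioned in the talk.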
And now we can move on to questions and debate.

Can we just ask questions? Okay. So the most obvious question, I guess, is how does JAX compare to established frameworks like PyTorch or TensorFlow? What would be the optimal scenario for JAX versus PyTorch, for example? What benefits would I have if I chose JAX over PyTorch?

That's an excellent question. I think part of the answer would certainly be flavored by personal perspective, but we also have members of the JAX core team here today, and I think they should have first dibs on answering. Matt, JJ, Skye, do you have an answer?

I can say some things at a high level, I guess. We think PyTorch is really great, and TensorFlow is really great too; a lot of people are really happy with those things, and we don't want to frame things in a zero-sum way. I think there are things that PyTorch does better than JAX, and there are things that JAX does better. I'm more familiar with the JAX side, so I can say some things about what we're trying to do with it and ways in which those things might help you. As folks talked about in these talks, JAX makes a bet on functional programming. To the user, that can mean being closer to math, and that's nice, but it also means that JAX can provide some capabilities that work really well: automatic batching like vmap, that sort of thing, and fancy autodiff. We started building JAX because we worked on automatic differentiation, so I think JAX still has a lot of cool features there: forward mode and reverse mode are composable, and the way they interact with vmap lets you compute Jacobians and Hessians really fast. It has some experimental features for exponentially faster, very high-order autodiff, so if you want to take, say, 100th-order Taylor expansions, it's quite good at that. These are things that are afforded, in part, by betting on functional programming.
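A small illustration of the composability described here, assuming a toy scalar function: reverse mode nested inside forward mode gives a Hessian, and a Hessian-vector product built from jvp over grad can be vmapped over basis vectors to recover the same matrix. The hvp helper is written here for illustration; jax.hessian, jax.jacfwd, jax.jacrev, and jax.jvp are core JAX functions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.1, 0.2, 0.3])

# Forward-over-reverse: jacfwd of jacrev yields the (3, 3) Hessian.
hess = jax.jacfwd(jax.jacrev(f))(x)

def hvp(f, x, v):
    # Hessian-vector product: forward-mode derivative of the gradient along v,
    # without materializing the Hessian.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# vmap the HVP over the standard basis: row i is H @ e_i, i.e. column i of the
# (symmetric) Hessian, so the stacked result equals the full matrix.
hess_from_hvps = jax.vmap(lambda v: hvp(f, x, v))(jnp.eye(3))
```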
And so you get these transformations. I'd say JAX is about the transformations, maybe first and foremost. Another thing to say is that it's very compiler-oriented. I think one of the speakers mentioned that she likes being able to dig into the guts of an optimizer: there's not going to be an optimizer kernel, like an Adam update op written in C++; it's all in user-level Python. That's because JAX is designed around being compiler-oriented, and that also means you can do some things well, like stage out entire parts of your computation. This is the extreme end, not necessarily what you'll see yourself every day, but JAX had some MLPerf entries that set world records for training some neural networks extremely quickly, like English-German neural machine translation training in about 16 seconds. That comes out of being compiler-oriented: JAX is built around giving you the control to stage out parts of your program from Python and hand them off to XLA, a super-optimizing compiler that can project things onto not just your GPU, your single card, or your TPUs, but also entire TPU supercomputers. Anyway, at a high level I'd say the bet on compilers, the functional programming, the focus on providing many transformations, and even allowing experts to extend the system of transformations are the things that set JAX apart from how other systems have worked before. Whether that's actually useful to you depends: maybe what's most useful to you is having a ton of example code out there or a really big community, or there are some workloads that might perform better in other frameworks, for example. But hopefully that gives you some flavor, along with all the other things folks have talked about here, of where JAX could be useful.
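To see what "staging out" means in practice, jax.make_jaxpr prints the intermediate program (a jaxpr) that JAX traces from ordinary Python, which is what jit hands to XLA for compilation. The momentum step below is a hypothetical optimizer update written as user-level JAX, echoing the point that no fused C++ optimizer kernel is involved.

```python
import jax
import jax.numpy as jnp

def momentum_step(params, grads, velocity, lr=0.1, decay=0.9):
    # A hypothetical optimizer update, written as ordinary array code.
    new_velocity = decay * velocity + grads
    return params - lr * new_velocity, new_velocity

params = jnp.ones(3)
grads = jnp.full(3, 0.5)
velocity = jnp.zeros(3)

# The staged-out program that jit would compile with XLA.
print(jax.make_jaxpr(momentum_step)(params, grads, velocity))

new_params, new_velocity = jax.jit(momentum_step)(params, grads, velocity)
```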
perform better in other frameworks, for example. But hopefully that gives you some flavor, along with all the other things folks have talked about here, of where JAX could be useful.

Yeah, maybe I just wanted to add a couple of things very briefly before we move to the next question. One aspect is that JAX is comparatively very thin: it is very focused on doing a certain range of things and doing them very well, and I think that can be quite appealing if you really want to dig in and have full control over what is going on. The fact that you can really delve into JAX and get a good understanding of how everything works, because it's fairly compact and only does a focused set of things, is something I personally found very appealing, and it might resonate with other people too. The second thing, which Matt already touched on, is that this functional programming style, which exposes all the primitives as composable function transformations, has been incredibly powerful, at least for me. Drew's example, as I worked with him on this project, has been one of the most eye-opening experiences for me, because we literally just wrote a single agent interacting with a single environment, dropped in a couple of vmaps and a couple of pmaps, and suddenly we were running a massive experiment across multiple TPUs, multiple environments, and multiple agents. So I just want to shout out again the power of composable function transformations, because it's really amazing.

Can you talk about the debugging experience? Like, what do you get when you inspect things, and also when jit is involved?

Yeah, I can probably take it. So the nice thing about debugging in JAX is that until you turn on jit, you are essentially just executing NumPy-style code. It just looks a lot like it
and it actually executes a lot like it, so you can literally pdb into your program, step through it line by line, and check what the arrays contain. Then you can just turn on jit to make it faster, but you can still do your debugging through this friendlier Python lens. So I think the debugging experience from that perspective is quite nice.

If I could tack onto that: we're always looking for ways to improve, like better error messages and better ways to track things. But on the subject of being able to pop into a Python debugger, that works as long as you're not using jit, even with things like autodiff or vmap. Something that uses both of those is the jacfwd function, for forward-mode Jacobians. If you have a function you're taking the jacfwd of, and you put a breakpoint in it, jump into the debugger, and start printing values, it'll actually show you both the primal point at which you're linearizing and the entire basis of tangent vectors, all the things you're pushing forward together, because both vmap and autodiff work in pure Python. I don't think it's necessarily the first thing you'd think to do, but you do have all your values there, and you can poke at them and look at them. So hopefully that helps a little bit on the debugging front.

Awesome. So, next question. I don't see any hands raised, but if folks want to raise hands instead of asking questions in the chat, we can also do that. The next question is from Sure Hill: is there a tf.data for JAX planned by either the JAX team or the DeepMind team? Oh, and sorry, apologies for that, Byron. Byron, it seems like you have your hand raised and are waiting to speak, is that correct?

Yeah, but you can address the
data one first.

Sorry, the question is: are there plans for a tf.data for JAX?

Yes, okay. Well, one possible answer to that question is that many people actually do use tf.data with JAX, because you can just go through NumPy arrays as the common interchange format. There's no fancy integration needed; you can use the two libraries like any other libraries. And of course this applies to other data loaders as well, or your own custom data loaders. So we don't have any immediate plans on the JAX team to create our own data-loading library. We think many other groups and teams have done a great job of this, and we don't know if we can do better. But I'm interested if there are any particular use cases or features that aren't well served by the existing data-loading frameworks.

Awesome, thank you, Skye. And Byron, go for it.

Cool. So first of all, excuse me, I want to thank both the JAX team and the DeepMind team, because I've been using Haiku to reimplement some things in RLax over the last several months, and I've really enjoyed it. Like Matteo said, I find the functional format of it really exciting. One thing that I don't think anyone mentioned, but that I found really nice, is that serialization of model params is super easy. If you have your params out, you save them off, you load them back in, and you're done. There's nothing you have to worry about with fancy loaders or anything like that, so I really like that capability for checkpoints. But my one question, again having worked on this: can you speak to what the model is for contributing back to these projects? For instance, I noticed Haiku doesn't have bidirectional RNN support, so I had to roll my own. What is the process for contributing that?

So maybe I can
take the questions for the JAX ecosystem libraries, and maybe someone from the JAX team can say more about contributing to JAX core itself. We are definitely open to contributions; we have already taken quite a few, in RLax and in Optax, for instance. So I think the main message there is: open an issue on GitHub and let's have a conversation there. It's hard to say without knowing the details. We have an explicit aim to keep each library quite focused: similarly to how JAX core does one thing and does it well, in our ecosystem libraries we try to have Optax focused very clearly on optimization, Haiku very clearly on neural networks, and RLax very clearly on reinforcement learning. So there might be cases where there's valuable code that doesn't fit exactly into one of the ecosystem libraries, in which case we might have to make a call on whether or not it fits the scope of each library. But I think opening an issue on GitHub would definitely be a nice way of starting a conversation.

Excellent. Is there anything you wanted to add, Skye or Tom?

We love open source contributions, so yeah, come over to our GitHub and make them. Plus one to everything Matteo said. We actually just had someone add Windows build support, I think, and that was incredible: that was an issue number we had in the single digits, and some open source contributor just contributed it, so that was fantastic. So come to our issue tracker. We're often overloaded and time-limited, so if we don't respond, it's not lack of interest; it's just because we're scrambling with other things. But we love seeing people start discussions and make contributions. Cool.
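To make the "forward mode and reverse mode are composable" point from the JAX-versus-PyTorch answer concrete, here is a minimal sketch using JAX's public transformations `jax.jacrev`, `jax.jacfwd`, `jax.vmap`, and `jax.jit`. The function `f` and the input values are hypothetical, chosen only for illustration: reverse mode gives the gradient, forward mode applied on top of it gives the Hessian, vmap batches that over examples, and jit compiles the whole stack with XLA.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar-valued function of a vector, used only for illustration.
    return jnp.sum(jnp.sin(x) * x ** 2)


# Reverse mode (jacrev) gives the gradient; applying forward mode (jacfwd) to
# that gradient gives the Hessian. This forward-over-reverse composition is
# how jax.hessian is built.
hessian_fwd_over_rev = jax.jacfwd(jax.jacrev(f))

# Reverse-over-reverse computes the same matrix with a different strategy.
hessian_rev_over_rev = jax.jacrev(jax.jacrev(f))

# vmap maps the per-example Hessian over a leading batch axis, and jit stages
# the whole composed computation out to XLA as one compiled program.
batched_hessian = jax.jit(jax.vmap(hessian_fwd_over_rev))

x = jnp.array([0.1, 0.2, 0.3])
xs = jnp.stack([x, 2.0 * x])  # a batch of two inputs, shape (2, 3)

H = hessian_fwd_over_rev(x)   # shape (3, 3)
Hs = batched_hessian(xs)      # shape (2, 3, 3)
```

Because each transformation takes a function and returns a function, they nest in either order; the JAX autodiff cookbook notes that forward-over-reverse is typically the most efficient way to build a Hessian, which is why `jax.hessian` uses it.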
Excellent. Next up, it looks like Horace.

All right, yeah. So I guess I wanted to ask more about the debugging topic. I get how, in general, eager-mode JAX is easy to debug, because you can just use all your regular tools. But in the presence of transformations that aren't jit, does it still remain easy to debug? Like, what happens if you have a print statement and then you call grad or vmap on it?

That's a great question, and it's sort of what I was getting at, but maybe it was hard to explain in words without showing the code. JAX works by a tracing mechanism, which means it's propagating extra values inside your Python code. So if you have a function, you print out an intermediate variable in it, and you call grad of that function, you'll see something print out, and JAX will show you the information it has going on under the hood. You can use pdb to debug things; you just end up looking a little behind the scenes, under the covers of what JAX is doing. In the grad case, for example, when you print a value you'll see a string that shows the primal value at which you're linearizing your function and then, depending on whether you're using forward mode or reverse mode, possibly something about the tangent value. That's getting in the weeds, but a lot of JAX stays in Python; jit stages things out, but for the things that stay in Python, we try to make the debugging experience as reasonable as possible. vmap is another example: if you print out the value of an intermediate thing in a vmapped function, it'll actually show you the
value it's keeping behind the scenes, which has the full batch dimension. Even though your code thinks it's operating on single examples, if you print a value, JAX will say, "Actually, I'm hiding a whole batch of examples behind the scenes here." But all that said, those are just ways in which we've tried to make the Python debugging experience work, and I'm sure there are still lots of ways it could be improved.

Excellent, thank you, Matt, and thanks, everybody, for the questions; these are great. I think the next one is from Robin.

Thanks. Does it make sense to mix in a bit of JAX if you have an existing PyTorch application, or does it make more sense to always do completely one or the other?

I think it depends on a lot of situation-specific details and on what kind of mixture you're talking about, but I think there's actually a lot of potential for interesting ways to compose things. One small example is being able to hand off GPU-backed arrays efficiently between the two, without copies and without much overhead. That might be a nice feature if you want to write some parts of your program in PyTorch: you have a function that's partly backed by PyTorch operations and partly backed by JAX operations. I think both PyTorch and JAX implement the DLPack standard for exchanging memory like this. There are still some caveats (for example, a GPU synchronization might be incurred), but other than engineering details that folks are still working on, I think there are ways we can make these things super compatible, and that way users just win. There's no reason to lock people in if you have reason to build stuff together. I can't think of a specific situation off the top of my head, but we want to make it
so that you can compose these things and you're not limited in that way.

Maybe I would just add that these kinds of potential integrations might also be helped by the fact that JAX is quite focused and quite lean. It's not a huge, vertically integrated framework where, if you buy into one aspect of JAX, you need to buy into everything else, a specific neural network library, and so on. Being quite lean might actually make these kinds of integrations easier, although of course it's hard to discuss without specific details.

Skye, actually, in your demo yesterday, was part of the Hugging Face demo loading values from PyTorch checkpoints or something?

Yeah, the BERT fine-tuning demo involved downloading the Hugging Face BERT checkpoint in PyTorch and then loading it into JAX. But I think that probably went through NumPy arrays, because it's downloaded on the host, rather than through DLPack.

Perfect. So, Horace, did you have another question, or is your hand still raised?

No more questions for me.

No worries. Are there any questions in the chat that we didn't get a chance to answer? It looks like Julian answered Horace's question about performance, and Avital and David (or James) shared great links. Robert has his hand raised now. Excellent. Robert?

Hi. I had one question regarding the future of JAX and what we can expect, essentially, from both the DeepMind internal packages and the core team.

Awesome. I guess I can speak to the DeepMind part. In terms of our libraries, I think this model of having low-level libraries of components with incremental buy-in has been a real win for us. Part of the nature of research is that it's broad and unpredictable and moves very quickly. Often, when we try to go the other way and plan in advance what we think we need
to build, in terms of a press-play-and-run framework, by the time we've gotten there, the field has moved on. With these frameworks of components, perhaps an individual component someone contributes becomes less relevant than it was in the past, but that doesn't affect your ability to build others, contribute, expand, or whatever it may be. So this idea of identifying the core features that are required in our research, contributing them back, and using them as a substrate for sharing research throughout the organization and with the broader community, along with adding more libraries and building libraries on top of libraries, is working really well for us, and it's something that, at least in the short to medium term, we intend to continue. Maybe someone from the JAX team can answer the question with regard to JAX core.

Yeah, I can say something, but someone else from JAX core, interrupt me if you have a really good answer. I never quite know how to answer a question like that, other than to say that the folks who are working on JAX, myself included, love this project. We're pouring so much of ourselves into it, and to me that's as much of a guarantee that it's not going away as you can get, just because we're personally very invested. The JAX team has grown, and it's working well at places like DeepMind and Google Brain, and in research more broadly. I think the best days of JAX are ahead; I guess what I'm saying is that JAX is growing, and the folks working on it are very passionate, as you've seen. So I wouldn't worry at all about JAX going away or anything like that.

Awesome. We are at time, but we might have time for one more question, and it
looks like Robin had his hand raised.

Oh yeah, this looks amazing. I was wondering: what do you think are the stumbling blocks that people new to JAX often hit? Are there certain things that are often confusing for people or cause issues when they're just getting started?

We have a great link for that: there's something called "JAX - The Sharp Bits", a Colab in the docs, and I'll link it in the chat.

In addition, I guess I'll add one from the DeepMind side. One of the strengths of JAX, obviously, is this bet on functional programming, where you're phrasing everything in a functional paradigm in which things are stateless and you don't need to worry about side effects. We decided to diverge from that slightly with some of our libraries, such as Haiku, and the reason was that we had other considerations to balance: we had so many users already using Sonnet as their way of defining and reasoning about neural networks in TensorFlow, and we really wanted to be able to maintain that API. So there's certainly some cognitive overhead in writing something that looks object-oriented and is converted into functions for you, and it sometimes takes users some time to ramp up, but that was the trade-off we decided worked for us as an organization. There are certainly other JAX neural network libraries that take different approaches, which other people find work better for them, and in the spirit of incremental buy-in and cross-compatibility, you're free to use any part of any of those with any of our libraries as well. I think for the most part that's been working for people.

Thank you. I would maybe add a common one that was also mentioned in the chat: pseudo-random number generation. That is something that, as a
newcomer to JAX, might be a little surprising if you're used to NumPy, where you hardly have to reason about the state of the random number generator at all. I would say, though, that while it can be surprising at first, it actually pays off in the long run. Once you get used to this way of dealing with random number generation, at least personally, I ended up concluding that it was actually the right way all along. There is also a very nice document in the JAX documentation about the PRNG design, explaining the reasons and motivations for the way JAX deals with random numbers. I think that is a great read: as a newcomer, if you're a bit puzzled at the beginning by random numbers, just read that doc, because I think it will clarify a lot.

Yeah, I wholeheartedly agree with what Matteo is saying. Once you've used JAX's RNGs, all other RNGs that silently maintain their own state become very scary all of a sudden. So, definitely.

Cool. So we are at time. Thank you all so much for your questions. There might be one more slide left; is this correct, Matteo? Ha ha, there we go. Cool. So thank you, everybody, for coming. Please make sure to share your JAX projects on social media, whatever your favorite flavor of social media is. GitHub is also social, and it's my favorite social hangout place. But for anything that you create, please tag it with "JAX ecosystem" so that we can make sure to see it and share it. And if you're interested in contributing, there are many issues on the JAX repo that are good for first-time contributors, and many places to get involved. We look forward to seeing what you create. Thank you so much, everyone, for attending, and thank you to JAX core and all of the JAX ecosystem teams for presenting your work and for getting us all very excited. I know I am significantly
more excited, somehow, than I was even an hour ago. So thank you, everybody. Have a great day, have a great NeurIPS, and see you soon.
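The explicit pseudo-random number handling that Matteo flagged as a common stumbling block is easiest to see in a few lines of code. This is a minimal sketch using `jax.random.PRNGKey`, `jax.random.split`, `jax.random.normal`, and `jax.random.uniform`; the seed, the shapes, and the choice of four keys are arbitrary values picked for illustration, and the PRNG design document mentioned in the Q&A remains the authoritative explanation of why JAX works this way.

```python
import jax
import jax.numpy as jnp

# The PRNG state is an explicit key value, not hidden global state.
key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh randomness, split the key and consume the new subkey; keep the
# other half so it can be split again later.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

# Splitting into several keys gives each of N independent computations its
# own stream, which composes naturally with vmap.
keys = jax.random.split(key, 4)
samples = jax.vmap(lambda k: jax.random.uniform(k, ()))(keys)
```

The cost is that you have to thread keys through your code by hand; the benefit is that results are reproducible and that randomness behaves predictably under transformations like vmap, pmap, and jit.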