Seminar – Jeremy Bernstein – Metrized Deep Learning
Jeremy Bernstein
MIT CSAIL
Monday, March 31
11:00 AM – 12:00 PM (PDT)
CSE 1242
Title: Metrized Deep Learning
Abstract:
We build neural networks in a modular and programmatic way using software libraries like PyTorch and JAX. But optimization theory has not caught up to the flexibility of this paradigm, and practical advances in neural net optimization are largely driven by heuristics. In this talk, I will argue that to treat deep learning rigorously, we must build our optimization theory programmatically and in lockstep with the neural network itself. To instantiate this idea, we propose the “modular norm”, a norm on the weight space of general neural architectures that is built by stitching together norms on individual tensor spaces as the architecture is assembled. The modular norm has several applications: automatic Lipschitz certificates for general architectures, in both the weights and the inputs; automatic learning rate transfer across scale; and, most recently, a duality theory for the modular norm that leads to fast optimizers like “Muon”, which has set speed records for training transformers. We are building the theory of the modular norm into a software library called Modula to ease the development and deployment of metrized deep learning algorithms; you can find out more at https://modula.systems/.
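
To make the duality idea concrete, the sketch below shows a Muon-style update for a single weight matrix, following the publicly described recipe: accumulate momentum, then approximately orthogonalize the update with a Newton-Schulz iteration before applying it. The function names, coefficients, and hyperparameters here are illustrative assumptions, not the Modula library’s actual API.

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize G, i.e. push its singular values toward 1."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients (assumed)
    X = G / (G.norm() + 1e-7)          # normalize so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T                        # iterate in the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def muon_step(weight, grad, momentum_buf, lr=0.02, beta=0.95):
    """One Muon-style step for a single 2D weight matrix (illustrative)."""
    momentum_buf.mul_(beta).add_(grad)                    # momentum accumulation
    weight.add_(newton_schulz(momentum_buf), alpha=-lr)   # dualized update

# Usage on a toy linear layer's weight:
W = torch.randn(64, 32)
buf = torch.zeros_like(W)
muon_step(W, torch.randn_like(W), buf)
```

Orthogonalizing the update equalizes its singular values, which (under a spectral-type norm on the layer’s weight space) is the kind of dual step the talk’s duality theory makes precise.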
Biosketch:
Jeremy Bernstein is a postdoc in CSAIL at MIT, advised by Phillip Isola. His goal is to uncover the computational and statistical laws of natural and artificial intelligence, and thereby to design learning systems that are more efficient, more automatic, and more useful in practice. He holds a PhD in Computation and Neural Systems from Caltech and Bachelor’s and Master’s degrees in Physics from the University of Cambridge. He was a recipient of the NVIDIA Graduate Fellowship.