[00:07] (7.60s)
All right. [music] Hello everyone.
[00:11] (11.04s)
How you guys doing? Welcome to the first
[00:14] (14.32s)
ever YC paper club. This is like a very
[00:17] (17.92s)
exciting thing. [applause]
[00:21] (21.52s)
Absolutely thrilled with the response.
[00:23] (23.36s)
We had over a thousand folks that
[00:24] (24.96s)
applied to come in. It was a very hard
[00:27] (27.12s)
selection. If you guys have friends that
[00:28] (28.88s)
didn't make the cut, I'm very sorry.
[00:30] (30.48s)
We kind of need to keep it to
[00:32] (32.40s)
about a hundred. Um and so we selected a
[00:35] (35.44s)
very very cool group. Um
[00:38] (38.72s)
the mission is to create this kind of
[00:41] (41.28s)
community of great founders and great
[00:45] (45.92s)
researchers and try to pull them
[00:47] (47.20s)
together. I guess just for you guys to
[00:49] (49.04s)
get a sense for how cool the people in
[00:51] (51.28s)
this room are. Um, raise your hand if
[00:54] (54.08s)
you have at least five citations,
[00:59] (59.36s)
10 citations,
[01:02] (62.08s)
a 100 citations,
[01:04] (64.88s)
a thousand citations.
[01:08] (68.16s)
Wow, this is insane. Okay, 10,000
[01:10] (70.56s)
citations. Oh my god. Okay. All right.
[01:14] (74.32s)
This is awesome. I I would go up to
[01:15] (75.92s)
300,000, but I think it's like Chris
[01:17] (77.28s)
Manning and that's about it. Um, so, uh,
[01:20] (80.48s)
raise your hand if you've raised at
[01:21] (81.76s)
least a million dollars.
[01:24] (84.64s)
Raise your hand if you've raised at
[01:26] (86.16s)
least $5 million.
[01:28] (88.88s)
At least $10 million,
[01:31] (91.92s)
at least $50 million.
[01:35] (95.04s)
We still got one. We still got two over
[01:36] (96.72s)
here. All right. [laughter] Okay.
[01:39] (99.12s)
Awesome. The hidden mission that I'll
[01:41] (101.12s)
also kind of add on this is we had uh
[01:43] (103.20s)
Harj and I had um this uh awesome uh
[01:46] (106.96s)
breakfast in uh Woodside and this place
[01:49] (109.68s)
is so so unique and special and we kind
[01:53] (113.04s)
of just don't use it enough at YC. So
[01:55] (115.04s)
the hidden mission is to make Pioneer
[01:56] (116.56s)
great again. And so I went through
[01:58] (118.32s)
winter 16 here. Um it was an
[02:01] (121.12s)
unbelievable time. I think 140 companies
[02:04] (124.00s)
went through that batch. 10 or 15 of
[02:06] (126.08s)
them are unicorns. It's an insane
[02:07] (127.84s)
number. um WPY, uh Astranis, um Deepgram,
[02:12] (132.00s)
all these companies were in the
[02:13] (133.52s)
batch and during that time uh Sam was
[02:16] (136.64s)
still running the show and basically
[02:19] (139.52s)
sitting right there would be me,
[02:21] (141.52s)
Andrej Karpathy, Wojciech Zaremba and Greg
[02:24] (144.16s)
Brockman because they were starting this
[02:25] (145.60s)
thing called OpenAI and it was like the
[02:27] (147.36s)
very early stages and there was like not
[02:29] (149.68s)
that many AI companies. So they would
[02:31] (151.52s)
ask me and Steve from Debb like what are
[02:34] (154.16s)
you guys what are you working on? What
[02:35] (155.28s)
are the problems you're working on? and
[02:36] (156.56s)
they're looking for problems because
[02:37] (157.52s)
they didn't even know what to research.
[02:38] (158.80s)
And so it was such a such a special
[02:40] (160.24s)
time. This place is so special uh to to
[02:43] (163.12s)
me in particular uh to Harj as well. And
[02:45] (165.92s)
we just don't really use it
[02:48] (168.00s)
enough. So I wanted um to kind of make
[02:50] (170.24s)
this community down here. And I also
[02:52] (172.32s)
think that 100% of the AI talent or AI
[02:56] (176.48s)
people in the Bay Area, probably about
[02:58] (178.80s)
half of them are in the city maybe is a
[03:00] (180.64s)
good number. There's Anthropic, uh
[03:02] (182.40s)
there's OpenAI, there's Cursor, there's
[03:04] (184.16s)
all this stuff in the city. Then there's
[03:05] (185.52s)
a lot that are down here that are not
[03:07] (187.28s)
making the trek up to the city to join
[03:09] (189.04s)
YC. And so he's like, "Yes,
[03:10] (190.72s)
emphatically, yes." Um, and so you have
[03:13] (193.36s)
Google DeepMind right on the corner. You
[03:15] (195.04s)
have um Tesla, you have xAI, you have
[03:17] (197.28s)
Thinking Machines, you have all these
[03:18] (198.24s)
other people in Palo Alto, you have a
[03:20] (200.08s)
lot of startups. And so uh I wanted to
[03:22] (202.40s)
kind of like solve six birds with one
[03:24] (204.48s)
stone and kind of pull together this
[03:26] (206.40s)
community down here as well. And Harj uh
[03:28] (208.64s)
uh is super excited about it as well.
[03:30] (210.48s)
And so thank you very much Harj for
[03:31] (211.92s)
letting us do this. We got uh five great
[03:34] (214.08s)
papers here coming up. The first one is
[03:36] (216.56s)
Tanishk Speculative Speculative
[03:38] (218.64s)
Decoding. You want to come up?
[03:42] (222.64s)
All right.
[03:45] (225.76s)
Do you want me to pull it on? Yeah, I
[03:47] (227.20s)
got you.
[03:51] (231.12s)
I know it uh looks like maybe I was
[03:52] (232.96s)
sloppy and I added an extra word in the
[03:54] (234.64s)
title, but uh it is intentional um and
[03:56] (236.96s)
it'll make sense in uh good time. Um my
[03:59] (239.52s)
name is Tanishk. I'm a grad student at
[04:00] (240.88s)
Stanford. Um, this is a project I worked
[04:02] (242.56s)
on with Tri Dao and Avner May. I'm going to
[04:05] (245.60s)
be evangelizing inference for people
[04:08] (248.40s)
today. Hopefully, you'll be inference
[04:10] (250.80s)
enjoyers by the end. So, I'm not sure
[04:13] (253.84s)
how much I have to motivate inference. I
[04:16] (256.48s)
worked on training before inference. And
[04:18] (258.16s)
the sort of mental model I had
[04:20] (260.32s)
in mind for how inference works was you
[04:22] (262.08s)
know you do this beautiful craftsmanship
[04:24] (264.32s)
during the training process and you get
[04:25] (265.76s)
these like you know very intricate
[04:27] (267.52s)
weights and then you kind of just hand
[04:29] (269.76s)
it off and use them to generate tokens.
[04:31] (271.52s)
In my mind it's sort of like you have
[04:34] (274.16s)
the weights just multiply the matrices
[04:36] (276.24s)
so why do you need a team for it? Um I
[04:40] (280.00s)
was very confused but there is in fact a
[04:41] (281.76s)
lot of subtlety involved. Um it's a lot
[04:43] (283.92s)
of fun the algorithms and systems behind
[04:45] (285.92s)
inference at scale. I'm not sure I need
[04:48] (288.32s)
to spend too long talking about why
[04:50] (290.56s)
inference is important. Um there is one
[04:53] (293.44s)
point I want to make that I don't hear
[04:54] (294.96s)
people talk about enough. So things you
[04:57] (297.68s)
may have heard are that inference costs
[05:00] (300.80s)
are high. They dominate training costs
[05:02] (302.40s)
when you're serving a model for billions
[05:04] (304.64s)
of users or you know 10 Claude Code power
[05:09] (309.04s)
users. That's trillions of tokens. Um,
[05:12] (312.24s)
not only are inference costs dominating
[05:14] (314.16s)
training costs, but even within
[05:16] (316.40s)
training, RL is starting to exceed the
[05:19] (319.68s)
compute requirements of pre-training.
[05:21] (321.60s)
And what is RL but a wrapper on
[05:23] (323.28s)
inference, right? So, these are two
[05:26] (326.80s)
things you've probably heard before. The
[05:28] (328.40s)
third is one I fear isn't really talked
[05:31] (331.12s)
about, but it's the reason that I
[05:32] (332.56s)
started working on inference, and I use
[05:35] (335.52s)
the phrase working on inference lightly.
[05:37] (337.12s)
This was the only inference project I've
[05:38] (338.64s)
ever done. Um, but the the reason I got
[05:41] (341.68s)
interested in making inference fast was
[05:43] (343.44s)
not because of cost or for convenience.
[05:45] (345.76s)
It was entirely because of capability.
[05:48] (348.48s)
So the claim I'm going to make and maybe
[05:50] (350.00s)
this is the one thing to take away from
[05:52] (352.00s)
the message I'm trying to send in this
[05:53] (353.44s)
talk is that inference today is seen as
[05:57] (357.12s)
a sort of like cost or convenience
[05:59] (359.28s)
lever. But uh in one two or 3 years
[06:02] (362.64s)
inference is going to be seen as a
[06:04] (364.24s)
capability. And what I mean by that is
[06:06] (366.88s)
that if you have a method, an algorithm,
[06:09] (369.44s)
a system where its performance scales
[06:11] (371.84s)
with the amount of thinking it does,
[06:14] (374.24s)
then fundamentally the speed at which
[06:16] (376.40s)
you can do inference, the tokens per
[06:18] (378.32s)
second is exactly the peak intelligence
[06:20] (380.96s)
that you can deliver.
[06:23] (383.36s)
So inference should be thought of as not
[06:24] (384.96s)
so much as a a cost or or convenience
[06:27] (387.52s)
factor, but as a capability. Um, and
[06:29] (389.76s)
that's why I got interested in it. I I
[06:31] (391.28s)
wanted to work towards the future where
[06:32] (392.56s)
we have an entire data center of
[06:34] (394.88s)
20,000 B200s just working on the Riemann
[06:37] (397.68s)
hypothesis. Um okay, yes, that's the
[06:42] (402.24s)
future that uh I had in mind. Perhaps
[06:44] (404.00s)
this meme is a little outdated because
[06:45] (405.28s)
it has an A100 on it, but uh yeah. Okay.
[06:49] (409.12s)
So to motivate things, here is an
[06:51] (411.84s)
example of fast inference. So I'm going
[06:54] (414.80s)
to give you a little demo of uh three
[06:56] (416.72s)
algorithms side by side. We're going to
[06:58] (418.24s)
sample, you know, a code prompt from vLLM
[07:01] (421.28s)
with just normal autoregressive
[07:02] (422.64s)
decoding. We're going to use their
[07:04] (424.48s)
speculative decoding. And then I'm going
[07:06] (426.24s)
to put next to it the sort of janky
[07:08] (428.56s)
hand-rolled inference engine I wrote over
[07:10] (430.32s)
a summer for this project. Um, whose
[07:12] (432.40s)
main strength is just that it implements
[07:13] (433.84s)
a new algorithm and so you can see them
[07:16] (436.40s)
side by side. SSD's on the right and you
[07:18] (438.96s)
can see it is quite a bit faster than
[07:21] (441.20s)
what you can get if you try to use an
[07:22] (442.64s)
open source engine. Um, and it's not the
[07:24] (444.80s)
systems, it's it's the algorithm. Um so
[07:27] (447.20s)
yeah that's what we want to work towards
[07:28] (448.64s)
understanding both how speculative
[07:30] (450.40s)
decoding works as well as the algorithm
[07:32] (452.56s)
on the right.
[07:34] (454.72s)
Okay. Um I'll start by introducing what
[07:38] (458.08s)
speculative decoding is how it works and
[07:39] (459.92s)
then we'll move into what speculative
[07:42] (462.16s)
speculative decoding is. I hope that if
[07:44] (464.56s)
you have like a reasonably strong
[07:46] (466.16s)
understanding of how speculative
[07:47] (467.76s)
decoding works the the problem that SSD
[07:50] (470.40s)
is trying to solve will feel very
[07:51] (471.44s)
motivated and and the algorithm should
[07:52] (472.88s)
just become clear in good time.
[07:56] (476.32s)
Okay, so this is the schematic I'm going
[07:57] (477.92s)
to use to explain how vanilla
[07:59] (479.52s)
speculative decoding works. Um, it has a
[08:02] (482.64s)
small model, the tiny llama up top, as
[08:04] (484.96s)
well as a big model, the big llama. And
[08:07] (487.12s)
our goal is simply to sample fast from
[08:10] (490.16s)
the big llama. We want tokens generated
[08:12] (492.16s)
from the big model. And we're going to
[08:13] (493.20s)
use a small model as a sort of proxy or
[08:14] (494.96s)
an instrument to be able to sample
[08:16] (496.32s)
quickly from the big model. Okay. So,
[08:18] (498.80s)
what the draft is going to be
[08:20] (500.00s)
responsible for is basically generating
[08:22] (502.24s)
a bunch of tokens one by one. One by one
[08:25] (505.12s)
is important. It's autoregressive. So
[08:26] (506.56s)
you need to do three forward passes on
[08:27] (507.92s)
the draft or you know however many some
[08:29] (509.76s)
constant number. Um and these are going
[08:31] (511.76s)
to be guesses for what the draft
[08:33] (513.92s)
believes that the big model is going to
[08:36] (516.08s)
output next. It wants to sort of predict
[08:38] (518.08s)
ahead of time. The job that the big
[08:40] (520.08s)
model has, I'm going to call it the
[08:42] (522.08s)
target model, is verifying these
[08:43] (523.84s)
guesses. What does verification mean?
[08:46] (526.24s)
Verification means doing one forward
[08:48] (528.48s)
pass over these generated tokens to see
[08:51] (531.76s)
how likely it is that the big model
[08:53] (533.76s)
would have generated them. The sort of
[08:56] (536.08s)
key asymmetry here, the reason that
[08:57] (537.76s)
speculation works is that it is easier
[09:01] (541.76s)
to verify than to generate. This is a
[09:04] (544.64s)
feature of the transformer architecture
[09:06] (546.16s)
where you can get the probabilities for
[09:07] (547.84s)
many tokens in a sequence in parallel in
[09:09] (549.76s)
one forward pass. Um but you can't
[09:11] (551.92s)
generate them in parallel. Autoregressive
[09:13] (553.68s)
decoding is uh one at a
[09:16] (556.24s)
time. Um so we're leaving the
[09:17] (557.84s)
autoregressive decoding which is slow uh to
[09:20] (560.32s)
a very quick and small model and then
[09:22] (562.40s)
we're doing just one forward pass on
[09:24] (564.32s)
these tokens. And the way you verify
[09:26] (566.40s)
tokens is basically by having the big
[09:28] (568.56s)
model look at the probabilities of each
[09:30] (570.72s)
of the generated tokens and see how
[09:32] (572.80s)
plausible it is that it would have
[09:34] (574.80s)
generated those tokens. And sort of the
[09:36] (576.96s)
intuition here is that we will accept
[09:39] (579.20s)
precisely those tokens that the big
[09:40] (580.96s)
model could plausibly have generated.
[09:43] (583.28s)
Its probabilities were reasonably high.
[09:45] (585.20s)
There are subtleties in exactly what the
[09:46] (586.48s)
algorithm is um that I'm going to gloss
[09:48] (588.40s)
over, but that's the way to think about
[09:49] (589.52s)
it. Um and then we're going to find a
[09:51] (591.20s)
point perhaps where we don't think it's
[09:52] (592.72s)
plausible the big model would have
[09:53] (593.92s)
generated those tokens and we're going
[09:55] (595.36s)
to reject those tokens. So in the little
[09:57] (597.20s)
schematic on the right uh there the
[09:59] (599.36s)
draft samples three and the big model
[10:02] (602.16s)
verifies them and concludes that only
[10:03] (603.60s)
the first token was something it would
[10:05] (605.12s)
plausibly have generated. It will reject
[10:06] (606.96s)
the second token onwards and importantly
[10:10] (610.32s)
this is a sort of critical but subtle
[10:12] (612.40s)
detail of vanilla speculative decoding
[10:14] (614.80s)
because you have the probabilities at
[10:16] (616.24s)
each of the sequence positions. You can
[10:18] (618.08s)
sample an extra token at the point at
[10:20] (620.64s)
which you rejected a token for free as
[10:23] (623.76s)
in without doing any more forward
[10:25] (625.12s)
passes. And so that yellow token is what
[10:27] (627.12s)
I'm going to call a bonus token that you
[10:28] (628.56s)
sample for free. This is going to be
[10:29] (629.92s)
important in SSD. Um, so yeah, that's uh
[10:33] (633.44s)
that's an important conceptual point.
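The verification step described above can be sketched in a few lines. The talk glosses over the exact acceptance rule, so this sketch assumes the standard speculative sampling rule from the literature (accept a drafted token with probability min(1, p/q), and on rejection sample from the normalized residual max(0, p - q)). The names `verify` and `sample`, and the use of plain dicts as token distributions, are illustrative assumptions rather than anything from the speaker's engine.

```python
import random


def sample(weights, rng):
    """Draw one token from an unnormalized {token: weight} dict."""
    tokens = list(weights)
    return rng.choices(tokens, weights=[weights[t] for t in tokens], k=1)[0]


def verify(draft_tokens, q_dists, p_dists, rng):
    """Verify K drafted tokens against the target model.

    q_dists: K draft distributions, one per drafted position.
    p_dists: K + 1 target distributions, all obtained from ONE target
             forward pass (one per drafted position plus the next one).
    Returns (accepted_tokens, bonus_token).
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_dists[i], q_dists[i]
        # Accept when the target finds the drafted token plausible enough.
        if rng.random() < min(1.0, p.get(tok, 0.0) / q[tok]):
            accepted.append(tok)
            continue
        # First rejection: the bonus token comes from the residual
        # distribution at this same position, with no extra forward pass.
        residual = {t: max(0.0, p[t] - q.get(t, 0.0)) for t in p}
        return accepted, sample(residual, rng)
    # Everything accepted: the bonus token is the target's next token.
    return accepted, sample(p_dists[len(draft_tokens)], rng)


if __name__ == "__main__":
    rng = random.Random(0)
    q = [{"a": 1.0}, {"b": 1.0}, {"c": 1.0}]
    p = [{"a": 1.0}, {"b": 0.0, "x": 1.0}, {"c": 1.0}, {"z": 1.0}]
    print(verify(["a", "b", "c"], q, p, rng))
```

The toy inputs mirror the schematic: three drafted tokens, only the first is accepted, and the bonus token is drawn at the position of the first rejection.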
[10:37] (637.52s)
This sort of sets the stage for how SSD
[10:41] (641.36s)
works. Okay, we have our schematic.
[10:45] (645.52s)
And the way we've set up speculative
[10:47] (647.68s)
decoding is that it's a way to exchange
[10:49] (649.76s)
flops for latency. So speculation in
[10:51] (651.92s)
general is not actually something that
[10:54] (654.32s)
uh only LLMs do. It's like a a deep idea
[10:56] (656.80s)
in computer science. It's used in CPUs
[10:58] (658.48s)
as well where the general philosophy is
[11:00] (660.72s)
that you precompute something ahead of
[11:02] (662.24s)
time. Some of what you precompute may be
[11:04] (664.64s)
useless because it may be an incorrect
[11:06] (666.32s)
prediction of the future, but if you're
[11:08] (668.00s)
right, you get to fast forward in time
[11:10] (670.08s)
um and you get lower latency as a
[11:11] (671.60s)
result. So the the sort of like moral
[11:13] (673.52s)
philosophy of speculative decoding is
[11:14] (674.80s)
that it's currency exchange. The
[11:16] (676.40s)
difficulty with normal speculative
[11:18] (678.08s)
decoding is that you can't push this
[11:21] (681.04s)
arbitrarily far. You cannot keep
[11:23] (683.20s)
sampling more and more tokens on the
[11:24] (684.72s)
draft and keep getting speed ups because
[11:26] (686.56s)
at some point you're going to get to a
[11:27] (687.84s)
point where you're spending a lot of
[11:28] (688.80s)
time drafting and you're not accepting
[11:30] (690.40s)
all that many tokens. And in particular,
[11:32] (692.40s)
like a big bottleneck in vanilla
[11:33] (693.84s)
speculative decoding is the sequential
[11:35] (695.52s)
dependence between the small llama and
[11:36] (696.88s)
the big llama. Um the drafting in round
[11:39] (699.28s)
t has to take place before the
[11:41] (701.36s)
verification of those tokens. um and the
[11:44] (704.24s)
drafting in round t+1 can't take place
[11:46] (706.72s)
before you know the outcome of
[11:48] (708.00s)
verification of the previous round
[11:49] (709.84s)
because you need that as a prefix to
[11:51] (711.44s)
draft on top of. So there's a logical
[11:54] (714.24s)
dependency here. The goal of SSD is very
[11:57] (717.52s)
simple. There's a lot of gnarly and
[11:59] (719.52s)
subtle details but the high-level idea is
[12:01] (721.28s)
incredibly simple. It is simply to
[12:03] (723.76s)
parallelize this sequential operation.
[12:06] (726.96s)
We want drafting and verification to be
[12:10] (730.24s)
happening at the same time.
[12:12] (732.80s)
Normally in speculation they happen on
[12:14] (734.32s)
the same hardware and that's fine
[12:15] (735.84s)
because there's only one of them
[12:16] (736.88s)
happening at a time. In our setup
[12:18] (738.72s)
they're going to be happening at the
[12:19] (739.60s)
same time. So we're not going to be
[12:20] (740.72s)
collocating them. And the main question
[12:23] (743.60s)
basically becomes how do you parallelize
[12:26] (746.24s)
this inherently sequential algorithm
[12:28] (748.64s)
that has a logical dependency. Um and
[12:30] (750.64s)
the way we're going to do that is we are
[12:32] (752.40s)
going to have the draft model send back
[12:34] (754.80s)
its draft tokens in a certain round. So
[12:37] (757.60s)
we've sent back a bunch of blue tokens.
[12:39] (759.68s)
That's now the job of the verifier to do
[12:42] (762.24s)
a forward pass over and verify. And this
[12:44] (764.64s)
is going to take a while because a
[12:45] (765.84s)
verifier is a big model. What we on the
[12:48] (768.08s)
draft are going to do is basically start
[12:50] (770.88s)
anticipating the most likely
[12:52] (772.32s)
verification outcomes immediately.
[12:55] (775.04s)
As soon as we send back like a certain
[12:57] (777.04s)
round of speculation and once we we have
[12:59] (779.68s)
in mind some of the most likely
[13:01] (781.52s)
verification outcomes, we are going to
[13:03] (783.44s)
start drafting the next round on top of
[13:05] (785.76s)
those immediately while verification is
[13:08] (788.00s)
taking place. If we're right, the next
[13:10] (790.16s)
time the verifier asks for a draft,
[13:12] (792.48s)
we'll have it ready immediately. We're
[13:14] (794.48s)
entirely hiding the latency of drafting.
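A back-of-envelope latency model makes the benefit concrete. This is only a sketch under stated assumptions: the per-round timings are hypothetical, the miss path assumes the simplest backup strategy (drafting just in time, as in vanilla speculation), and the function names are made up for illustration. It also ignores any change in the number of accepted tokens per round.

```python
def vanilla_round_ms(k, t_draft_ms, t_verify_ms):
    """One round of vanilla speculative decoding: draft k tokens, then verify."""
    return k * t_draft_ms + t_verify_ms


def ssd_round_ms(k, t_draft_ms, t_verify_ms, hit_rate):
    """Expected round time when drafting overlaps with verification.

    Hit:  the next draft was precomputed during verification, so the round
          costs whichever of the two overlapping stages is slower.
    Miss: assumed fallback to drafting just in time, i.e. the vanilla cost.
    """
    hit = max(t_verify_ms, k * t_draft_ms)
    miss = t_verify_ms + k * t_draft_ms
    return hit_rate * hit + (1.0 - hit_rate) * miss


if __name__ == "__main__":
    # Hypothetical numbers: 4 drafted tokens at 5 ms each, 30 ms to verify.
    print(vanilla_round_ms(4, 5.0, 30.0))
    print(ssd_round_ms(4, 5.0, 30.0, hit_rate=0.9))
```

With a hit rate of zero the model collapses to the vanilla cost, which is a useful sanity check on the formula.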
[13:16] (796.32s)
If we're wrong, well, we'll have to
[13:17] (797.76s)
figure out a backup strategy. And
[13:19] (799.12s)
there's uh there's there's there's some
[13:20] (800.96s)
subtleties on what you do and how you do
[13:22] (802.40s)
it there. Um so yeah, the way
[13:25] (805.04s)
speculative decoding looks is like this.
[13:26] (806.56s)
And perhaps unsurprisingly, the analog
[13:29] (809.28s)
for SSD is this diagram on the right.
[13:32] (812.00s)
where now drafting and verification
[13:33] (813.76s)
happen in parallel. um the the principal
[13:37] (817.84s)
difficulty or algorithmic design space
[13:40] (820.00s)
in SSD is how do you predict
[13:42] (822.08s)
verification outcomes ahead of time. I
[13:44] (824.40s)
thought verification is where you are
[13:46] (826.40s)
leveraging the intelligence of the big
[13:47] (827.76s)
model that should by construction be
[13:49] (829.36s)
difficult to predict. Um and the
[13:51] (831.12s)
intuition for why it's plausible at all
[13:53] (833.44s)
is that you can make many guesses on the
[13:55] (835.44s)
draft for what a verification outcome
[13:57] (837.36s)
is. And a verification outcome here is
[13:59] (839.76s)
just you know a plausible number of
[14:01] (841.12s)
accepted tokens and then a bonus token
[14:04] (844.16s)
on top of that. Now this is hard to
[14:06] (846.16s)
predict because a bonus token comes from
[14:07] (847.76s)
a vocabulary which has size you know
[14:09] (849.44s)
tens to hundreds of thousands. Um so
[14:11] (851.28s)
it's a large space to cover um but it
[14:13] (853.68s)
turns out you can do it well um
[14:15] (855.44s)
reasonably well. You can get it right
[14:16] (856.80s)
about 80 to 90% of the time which is
[14:18] (858.56s)
more than enough to get big speed ups.
[14:20] (860.48s)
And the way we do that, the short of it
[14:22] (862.40s)
is basically we use information on the
[14:24] (864.48s)
draft to predict what the verification
[14:26] (866.48s)
outcome is likely to be. When we
[14:27] (867.92s)
generated the blue tokens on the draft,
[14:29] (869.36s)
we had other tokens that we chose not to
[14:31] (871.12s)
sample. Those other tokens are plausible
[14:33] (873.52s)
verification bonus token candidates. And
[14:35] (875.92s)
so you basically use information from
[14:38] (878.40s)
the token distributions of the draft
[14:40] (880.40s)
model to predict what likely outcomes on
[14:42] (882.40s)
the target are. And then once you have
[14:44] (884.08s)
all of these predictions, you can decode
[14:45] (885.76s)
them in parallel as just different
[14:47] (887.28s)
sequences that you're decoding on top of
[14:49] (889.60s)
a shared prefix. And voila, it uh it's
[14:53] (893.12s)
it gives you speedups because you get to
[14:54] (894.88s)
hide the latency of drafting altogether.
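One way to picture the prediction step is as a small cache keyed by verification outcome. The sketch below is an assumption-laden illustration, not the paper's implementation: it guesses bonus tokens by taking the draft's top alternatives at each position, uses a hypothetical `draft_fn` callback to produce the next round's draft, and loops sequentially where a real engine would decode all candidates as one batch over a shared prefix. It also assumes the draft exposes one extra distribution for the position after the last drafted token.

```python
def top_alternatives(dist, exclude, fanout):
    """Most likely tokens under the draft at one position, skipping `exclude`."""
    ranked = sorted((t for t in dist if t != exclude),
                    key=lambda t: dist[t], reverse=True)
    return ranked[:fanout]


def predict_outcomes(draft_tokens, q_dists, fanout):
    """Guess likely (num_accepted, bonus_token) verification outcomes.

    q_dists holds K + 1 draft distributions (assumption: one per drafted
    position plus one for the position after the last drafted token).
    """
    outcomes = []
    for n in range(len(draft_tokens) + 1):
        # If drafted token n is rejected, the bonus token cannot be that token.
        drafted = draft_tokens[n] if n < len(draft_tokens) else None
        for bonus in top_alternatives(q_dists[n], drafted, fanout):
            outcomes.append((n, bonus))
    return outcomes


def build_cache(prefix, draft_tokens, q_dists, fanout, draft_fn):
    """Pre-draft the next round for each predicted outcome."""
    cache = {}
    for n, bonus in predict_outcomes(draft_tokens, q_dists, fanout):
        cache[(n, bonus)] = draft_fn(prefix + draft_tokens[:n] + [bonus])
    return cache


def next_draft(cache, outcome, fallback):
    """Return (draft, was_hit); on a miss, fall back to drafting just in time."""
    if outcome in cache:
        return cache[outcome], True
    return fallback(), False
```

The `fanout` parameter is where the compute allocation question from the talk shows up: it does not have to be the same at every accepted-prefix length.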
[14:57] (897.12s)
Um there's also a an additional bonus
[15:00] (900.00s)
that since verification actually kind of
[15:01] (901.52s)
takes a while, you get more time to
[15:04] (904.08s)
draft uh in the first place. So you can
[15:05] (905.84s)
draft more tokens which increases the
[15:07] (907.68s)
expected tokens per round and sort of
[15:09] (909.68s)
gives you further speed ups. There's a
[15:11] (911.92s)
bunch of stuff that we work through in
[15:13] (913.44s)
the paper that's uh that's sort of
[15:14] (914.88s)
reckoning with the the implementation
[15:17] (917.28s)
details of this. One of them is how you
[15:19] (919.20s)
handle cache misses. One plausible thing
[15:21] (921.76s)
you could do perhaps naively is to just
[15:23] (923.44s)
fall back to ordinary speculation just
[15:24] (924.88s)
in time. Turns out that actually this is
[15:26] (926.72s)
not always optimal. Um there's
[15:28] (928.64s)
trade-offs. You know, as batch size
[15:30] (930.24s)
increases, you're going to fail to
[15:31] (931.60s)
predict some of the sequences'
[15:32] (932.88s)
verification outcomes. Um and so you
[15:35] (935.12s)
need different ways to predict and
[15:36] (936.56s)
handle cache misses. Should you be
[15:38] (938.72s)
allocating your compute on the draft
[15:41] (941.60s)
equally amongst plausible
[15:44] (944.40s)
prefix lengths? Uh the short answer is
[15:46] (946.40s)
no. You can be clever about it. And all
[15:48] (948.56s)
of this trickery just helps you increase
[15:50] (950.80s)
your cache hit rate, so to speak, the
[15:53] (953.60s)
amount of time you're able to correctly
[15:55] (955.12s)
predict verification outcomes. And
[15:57] (957.20s)
there's there's some trade-offs between
[15:58] (958.88s)
cache hit rate and the actual quality of
[16:01] (961.28s)
the drafting you're doing. Um and this
[16:03] (963.52s)
is totally non-obvious. Um, and and and
[16:06] (966.00s)
we we go into why that exists and how
[16:07] (967.52s)
you can navigate it in the paper. Um,
[16:09] (969.36s)
I'm happy to talk about it in in in Q&A
[16:11] (971.28s)
as well. Um, okay. So, what do you get
[16:15] (975.84s)
for the the price of this uh
[16:18] (978.56s)
mind-numbing
[16:20] (980.56s)
complexity and uh pain wrangling an
[16:22] (982.88s)
inference engine? Well, you get the
[16:25] (985.12s)
privilege of watching a number go up,
[16:27] (987.04s)
which I guess is the north star of all
[16:29] (989.44s)
AI research. And so here we have uh a
[16:33] (993.04s)
bunch of inference algorithms and
[16:34] (994.40s)
inference engines. The blue ones are
[16:36] (996.88s)
sort of uh my inference engine and uh
[16:39] (999.28s)
the light blue is just the baseline
[16:41] (1001.52s)
implementation of speculative decoding.
[16:42] (1002.96s)
The red is SGLang which is you know of
[16:46] (1006.00s)
all the inference engines we tried the
[16:47] (1007.44s)
fastest with speculative decoding and
[16:49] (1009.04s)
the dark blue is is SSD. Um and normally
[16:52] (1012.32s)
speculative decoding um is a is a win
[16:54] (1014.80s)
for latency but it's sort of unclear
[16:56] (1016.56s)
whether it's useful for throughput. um
[16:58] (1018.64s)
for us it turns out, in this setting,
[17:00] (1020.16s)
it's actually a win for both um and so
[17:02] (1022.48s)
you get numbers going up and you also
[17:04] (1024.40s)
get the ability next time you are at a
[17:06] (1026.80s)
San Francisco house party um to see
[17:08] (1028.88s)
other people dancing and knowing in the
[17:11] (1031.04s)
corner that uh you know what it takes to
[17:14] (1034.08s)
sample at 300 tokens per second uh for
[17:16] (1036.48s)
Llama 3 70B on 4 H100s. So this is uh
[17:19] (1039.60s)
sensitive information um but yeah that's
[17:22] (1042.64s)
that's about it. Thank you. [applause]
[17:28] (1048.67s)
[applause]
[17:32] (1052.16s)
All right, that was awesome. Okay, so
[17:35] (1055.28s)
for this next paper,
[17:39] (1059.28s)
this is um my first experience being
[17:42] (1062.56s)
scooped. The only issue is that he
[17:44] (1064.96s)
didn't talk to me and he did it six
[17:46] (1066.56s)
months before me. Um [laughter]
[17:49] (1069.12s)
but uh Isaac can vouch for me on this
[17:51] (1071.76s)
and maybe Robert as well. I basically
[17:54] (1074.80s)
fell in love with the diffusion policy
[17:56] (1076.16s)
paper. I was like this is definitely
[17:58] (1078.16s)
like you know a full uh predicting like
[18:01] (1081.52s)
th horizon steps for your robotic
[18:04] (1084.72s)
control. Um we have these amazing video
[18:07] (1087.12s)
models. Why don't we just use the video
[18:08] (1088.56s)
model to like run this like at test time
[18:11] (1091.52s)
to like play out the movie and where do
[18:14] (1094.08s)
I end up? And then you have your classic
[18:15] (1095.60s)
Push-T. And then I started like looking
[18:18] (1098.08s)
around uh and then DeepMind of course
[18:20] (1100.48s)
already did it. So [laughter]
[18:23] (1103.28s)
so I wasted like a month and it was not
[18:24] (1104.88s)
happy. But anyway, thank you very much.
[18:26] (1106.88s)
Please welcome Stannis.
[18:33] (1113.60s)
>> Hi everyone. I'm Stannis. I'm a staff
[18:35] (1115.68s)
research scientist at Google DeepMind.
[18:38] (1118.00s)
Uh currently I'm co-leading a new
[18:39] (1119.92s)
project on world modeling for robotics.
[18:42] (1122.08s)
uh where we try to build general purpose
[18:44] (1124.16s)
policies on top of video and world
[18:46] (1126.16s)
models. But uh this is an early work
[18:48] (1128.40s)
that I did about two years ago. Uh so
[18:52] (1132.08s)
this is before I switched to working on
[18:54] (1134.32s)
hardcore robotics and uh going into
[18:56] (1136.48s)
hardware really scaling up the data but
[18:59] (1139.20s)
uh you can probably see a lot of very
[19:01] (1141.68s)
similar ideas early version of ideas
[19:04] (1144.40s)
demonstrated on toy problems. Okay. So
[19:08] (1148.24s)
uh first to give some background what is
[19:10] (1150.32s)
the model predictive control. So model
[19:12] (1152.72s)
predictive control also called the
[19:14] (1154.48s)
receding horizon control uses a dynamics
[19:16] (1156.88s)
model or some people also call it a world
[19:18] (1158.88s)
model and uh an action selection mechanism
[19:22] (1162.48s)
uh which is a planner to construct
[19:24] (1164.08s)
agents that can solve a wide variety of
[19:26] (1166.16s)
tasks by means of maximizing a known
[19:29] (1169.04s)
objective. So the main advantages of
[19:32] (1172.32s)
model predictive control is uh it can
[19:34] (1174.80s)
adapt to novel reward functions at test
[19:36] (1176.96s)
time. So uh the dynamics models are also
[19:39] (1179.76s)
easier to learn and generalize better
[19:41] (1181.76s)
than just policies and the action
[19:44] (1184.24s)
proposal dynamics model factorization
[19:46] (1186.08s)
also allows easy adaptation to novel
[19:49] (1189.44s)
dynamics. So we're going to uh
[19:51] (1191.36s)
demonstrate some of these in later
[19:53] (1193.28s)
experiments but basically here we are
[19:55] (1195.52s)
showing the overall idea which is
[19:57] (1197.28s)
extremely simple. We have an action
[19:59] (1199.28s)
proposal which proposes a sequence of
[20:01] (1201.36s)
actions. We have a dynamics model which
[20:03] (1203.84s)
can evolve these actions and give you
[20:06] (1206.00s)
the future states. And uh finally we
[20:08] (1208.16s)
have some objective functions that we
[20:10] (1210.08s)
are trying to optimize. We basically use
[20:12] (1212.56s)
a planner to optimize that and uh pick
[20:14] (1214.88s)
the actions and execute it in the
[20:17] (1217.36s)
environment. So what is diffusion model
[20:19] (1219.12s)
predictive control? So the motivation
[20:22] (1222.08s)
mainly is uh uh there are a couple of
[20:24] (1224.40s)
problems we need to address in order to
[20:26] (1226.32s)
make MPC effective in practice. One the
[20:29] (1229.12s)
dynamics model needs to be accurate to
[20:30] (1230.88s)
avoid the problem of compounding errors
[20:33] (1233.20s)
and uh two the planning algorithm also
[20:35] (1235.68s)
needs to be powerful enough to select a
[20:37] (1237.60s)
good sequence of actions. So with DMPC
[20:40] (1240.88s)
what we did is to use diffusion models
[20:43] (1243.44s)
to learn both multi-step action
[20:45] (1245.68s)
proposals and multi-step uh dynamics
[20:48] (1248.72s)
models. So the advantages are mainly to
[20:51] (1251.92s)
reduce compounding errors and we also
[20:54] (1254.48s)
found that uh it can simplify the
[20:56] (1256.48s)
planning algorithm. Essentially we can
[20:58] (1258.08s)
just use a very simple uh sampling based
[21:00] (1260.56s)
planner and we can already outperform a
[21:02] (1262.72s)
lot of the previous uh approaches. So uh
[21:05] (1265.84s)
before we dive into the details also
[21:07] (1267.44s)
want to give a hierarchical view of some
[21:09] (1269.52s)
related works we organized. So there are
[21:12] (1272.24s)
a lot of related works in the literature
[21:14] (1274.16s)
and uh we organize it uh uh in this way
[21:17] (1277.12s)
where we basically look at how different
[21:19] (1279.28s)
approaches um so basically all
[21:21] (1281.60s)
approaches essentially try to build a
[21:23] (1283.60s)
joint uh distribution of the states and
[21:26] (1286.80s)
the actions but they do it in different
[21:28] (1288.96s)
ways and also use the different
[21:31] (1291.20s)
components in different ways. So for
[21:32] (1292.88s)
example, you can build it in a
[21:34] (1294.40s)
factorized way where you have rho(a),
[21:36] (1296.80s)
which is your policy predicting the
[21:39] (1299.12s)
actions, and then, conditioned on the action, you
[21:41] (1301.36s)
predict the state, which is a dynamics
[21:43] (1303.12s)
model. And for this you have the Dyna
[21:45] (1305.76s)
paradigm where you basically learn a
[21:47] (1307.52s)
model and use the model to also generate
[21:51] (1311.20s)
data in imagination and then learn a
[21:52] (1312.96s)
policy. But uh you can also do MPC uh
[21:56] (1316.24s)
where you uh essentially use a planner
[21:59] (1319.04s)
to select the actions and uh we also
[22:01] (1321.92s)
have uh some uh uh there are also
[22:03] (1323.92s)
approaches where you build a joint model
[22:05] (1325.76s)
of the state and actions and you're
[22:07] (1327.68s)
essentially also doing MPC and there are
[22:09] (1329.76s)
also model free approaches where you
[22:11] (1331.52s)
directly learn a policy. uh I won't dive
[22:13] (1333.76s)
into the full details but uh uh there
[22:16] (1336.24s)
are basically different trade-offs in
[22:18] (1338.40s)
terms of whether we can
[22:20] (1340.40s)
do runtime planning, adapting to
[22:22] (1342.96s)
novel rewards, adapting to novel
[22:25] (1345.12s)
dynamics, leveraging non-expert data, and
[22:27] (1347.84s)
also the uh general speed at runtime and
[22:31] (1351.60s)
there is also the distinction between
[22:33] (1353.68s)
whether you're doing single-step modeling
[22:35] (1355.44s)
or multi-step modeling.
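The factorized view described above can be written out explicitly. This is a sketch under assumed notation: rho is the action proposal from the talk, while the symbol p_d for the dynamics model, the horizon F, and the subscripts are illustrative choices rather than anything shown on the slides.

```latex
p\left(a_{t:t+F-1},\, s_{t+1:t+F} \mid s_t\right)
  = \rho\left(a_{t:t+F-1} \mid s_t\right)\,
    p_d\left(s_{t+1:t+F} \mid s_t,\, a_{t:t+F-1}\right)
```

A joint model learns the left-hand side directly. The factorized approach learns the two right-hand factors separately, which is what lets one factor be adapted while the other is kept fixed.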
[22:38] (1358.08s)
Okay. So coming to diffusion models,
[22:40] (1360.32s)
diffusion models have enjoyed a lot of
[22:42] (1362.64s)
success in generative AI,
[22:45] (1365.44s)
especially for generating images and
[22:47] (1367.52s)
videos. But uh in recent years they also
[22:50] (1370.16s)
found a lot of successes in robotics. So
[22:52] (1372.88s)
currently uh so here I'm also showing a
[22:55] (1375.20s)
slide where this is kind of the
[22:57] (1377.36s)
exploration space for diffusion-based,
[23:00] (1380.56s)
what I would call diffusion-based
[23:02] (1382.16s)
agents. So we of course start with the
[23:04] (1384.96s)
diffusion policy, where we condition on
[23:06] (1386.96s)
the observation and generate future
[23:09] (1389.04s)
actions. But then we also have this work
[23:11] (1391.60s)
called Diffuser, which you
[23:15] (1395.12s)
can think of as a way to
[23:17] (1397.84s)
jointly model observations and actions
[23:20] (1400.48s)
in trajectory space. Of course,
[23:23] (1403.76s)
these ideas are explored in tons of
[23:25] (1405.84s)
different papers but this is just a very
[23:27] (1407.76s)
simple and uh uh conceptual way to
[23:30] (1410.96s)
describe it. And uh then there's also
[23:33] (1413.20s)
Decision Diffuser, where we condition on
[23:35] (1415.68s)
the observations, or rather
[23:37] (1417.44s)
we condition on the history,
[23:39] (1419.52s)
directly generate future observations,
[23:41] (1421.44s)
and then train a separate inverse dynamics
[23:43] (1423.76s)
model to derive the actions and uh
[23:46] (1426.24s)
finally we have the diffusion model
[23:48] (1428.32s)
predictive control where we first have
[23:51] (1431.12s)
an action proposal to propose future
[23:53] (1433.04s)
actions and use a dynamics model to
[23:55] (1435.04s)
evolve it and then use a planner to
[23:58] (1438.88s)
select the actions. There are different
[24:01] (1441.28s)
uh trade-offs among these. So for
[24:03] (1443.20s)
example, diffusion policy is state of the art on
[24:06] (1446.24s)
complex control; day to day
[24:08] (1448.40s)
we still rely on it a lot.
[24:10] (1450.96s)
But this requires expert demonstrations.
[24:13] (1453.44s)
So essentially you can't move out of the
[24:15] (1455.92s)
behavior cloning paradigm. Uh for
[24:18] (1458.24s)
Diffuser, it's jointly modeling state
[24:20] (1460.64s)
and action, so it has implicit world
[24:23] (1463.36s)
modeling and also model based planning.
[24:25] (1465.76s)
And this is actually something that we
[24:27] (1467.68s)
are trying to explore at scale similar
[24:30] (1470.00s)
ideas. But uh and then there's also uh
[24:32] (1472.96s)
decision diffuser where you do
[24:35] (1475.12s)
observation only learning. The main
[24:37] (1477.20s)
benefit of this is it allows you to
[24:39] (1479.84s)
leverage video-only data, to learn
[24:43] (1483.12s)
from video-only data, because for
[24:45] (1485.04s)
robotics the data is the main
[24:47] (1487.44s)
bottleneck. And then finally there's
[24:49] (1489.36s)
diffusion MPC, which allows us to do
[24:51] (1491.76s)
runtime adaptation to novel rewards and
[24:54] (1494.56s)
novel dynamics. So what does the
[24:57] (1497.68s)
algorithm look like? It actually is
[24:59] (1499.76s)
extremely simple. We have an offline
[25:02] (1502.48s)
dataset and we have some
[25:05] (1505.36s)
hyperparameters. Essentially we are
[25:07] (1507.52s)
learning a
[25:10] (1510.72s)
couple of models, all from the offline
[25:12] (1512.72s)
dataset. We're learning a policy which,
[25:15] (1515.12s)
given the current observation,
[25:17] (1517.04s)
predicts the actions. We're learning a
[25:19] (1519.04s)
dynamics model which uh given the uh
[25:21] (1521.76s)
given the actions uh evolves the
[25:24] (1524.72s)
observations to predict the future
[25:26] (1526.72s)
states. And basically after
[25:29] (1529.76s)
learning all this, at
[25:32] (1532.80s)
inference time, when we actually deploy
[25:34] (1534.96s)
it as a policy, we sample action
[25:37] (1537.76s)
proposals, score them, rank them, and
[25:41] (1541.12s)
pick the best. But the main
[25:43] (1543.76s)
difference uh compared to previous
[25:45] (1545.68s)
approaches is uh we adopted a multi-step
[25:49] (1549.20s)
action proposal which uh is uh
[25:51] (1551.68s)
essentially very similar to a diffusion
[25:53] (1553.60s)
policy but if you train on more diverse
[25:55] (1555.92s)
data it can give you uh more coverage in
[25:58] (1558.48s)
terms of the action space and uh we are
[26:01] (1561.36s)
also using a multi-step
[26:04] (1564.40s)
um uh dynamics model which uh allows you
[26:07] (1567.44s)
to uh evolve for a long time horizon
[26:10] (1570.16s)
without a lot of compounding error. And
[26:13] (1573.12s)
also
[26:17] (1577.20s)
there's the fact that we leverage
[26:19] (1579.76s)
diffusion models, which are a really
[26:21] (1581.52s)
powerful way to model data, especially
[26:24] (1584.24s)
multimodal data, and what we
[26:26] (1586.96s)
observed empirically is that the stronger
[26:31] (1591.20s)
modeling capabilities also allow us
[26:34] (1594.16s)
to simplify the planning algorithm,
[26:36] (1596.88s)
so that we can just use such a simple
[26:39] (1599.92s)
planner to solve the tasks.
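The sample, score, rank, and pick loop described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: `sample_action_proposals` and `predict_states` are hypothetical stand-ins for the learned diffusion action proposal and the learned multi-step dynamics model, and the one-dimensional toy world is an assumption made so the example runs on its own.

```python
import random

def sample_action_proposals(obs, num_samples, horizon, rng):
    """Hypothetical stand-in for the learned multi-step action proposal."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(horizon)]
            for _ in range(num_samples)]

def predict_states(obs, actions):
    """Hypothetical stand-in for the learned multi-step dynamics model.
    Toy 1-D world: each action shifts the state by that amount."""
    states, state = [], obs
    for action in actions:
        state = state + action
        states.append(state)
    return states

def plan(obs, reward_fn, num_samples=64, horizon=5, seed=0):
    """Sample proposals, roll each one out, score it, and keep the best."""
    rng = random.Random(seed)
    proposals = sample_action_proposals(obs, num_samples, horizon, rng)
    scored = []
    for actions in proposals:
        states = predict_states(obs, actions)
        scored.append((sum(reward_fn(s) for s in states), actions))
    best_score, best_actions = max(scored, key=lambda item: item[0])
    return best_actions, best_score

def run_forward(state):
    """Reward for being far to the right."""
    return state

def stay_near_two(state):
    """A different reward, swapped in at planning time only."""
    return -abs(state - 2.0)

actions_fwd, score_fwd = plan(0.0, run_forward)
actions_two, score_two = plan(0.0, stay_near_two)
```

The reward function is only an argument to `plan`, so changing the objective at runtime needs no retraining of either stand-in model in this sketch.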
[26:43] (1603.28s)
Yeah. Um also contrasting with a few of
[26:46] (1606.00s)
the representative past works,
[26:48] (1608.96s)
including model-based
[26:51] (1611.76s)
offline planning and this Diffuser work
[26:54] (1614.40s)
which I mentioned it learns a joint
[26:57] (1617.28s)
model and uses a classifier free
[26:59] (1619.52s)
guidance for planning.
[27:02] (1622.56s)
Okay. Uh so yeah next to dive into some
[27:06] (1626.96s)
uh results uh there are lots of numbers
[27:10] (1630.00s)
but the short answer is uh we obtain
[27:12] (1632.72s)
very competitive results in fixed reward
[27:15] (1635.12s)
single task setups. This is just to
[27:17] (1637.36s)
demonstrate that uh uh the approach uh
[27:20] (1640.40s)
when you deploy it in uh single reward
[27:23] (1643.20s)
uh fixed reward single task setup it can
[27:26] (1646.00s)
perform competitively to the current
[27:28] (1648.00s)
state-of-the-art uh previous
[27:30] (1650.16s)
state-of-the-art approaches. But uh I
[27:32] (1652.88s)
think uh there are a couple of uh more
[27:35] (1655.20s)
interesting uh properties of DMPC. One
[27:39] (1659.44s)
is it can adapt to novel rewards at
[27:42] (1662.08s)
runtime. Here we are showing some uh
[27:44] (1664.72s)
examples where uh essentially we train
[27:47] (1667.44s)
the model to uh these are very simple
[27:50] (1670.40s)
MuJoCo tasks, but we train the model on
[27:52] (1672.80s)
just locomotion tasks: run forward
[27:55] (1675.68s)
and jump etc. But uh at inference time
[27:59] (1679.12s)
just by changing the reward
[28:01] (1681.20s)
function, we can make it exhibit
[28:04] (1684.72s)
novel behaviors like jumping etc. So
[28:09] (1689.28s)
uh here's another example where we show
[28:11] (1691.68s)
that uh uh DMPC can adapt to novel
[28:14] (1694.56s)
dynamics, while these kinds of joint
[28:17] (1697.44s)
modeling approaches struggle. This is
[28:19] (1699.76s)
really the benefit of the factorization
[28:22] (1702.24s)
of the action proposal and the dynamics
[28:24] (1704.88s)
model. So here the idea is we can
[28:28] (1708.32s)
keep the action proposal the same but uh
[28:30] (1710.88s)
we uh we have uh scenarios where the
[28:34] (1714.48s)
dynamics of the environment changed. So
[28:36] (1716.72s)
for example the walker has a broken left
[28:39] (1719.04s)
ankle and as a result when it starts to
[28:41] (1721.52s)
execute actions the consequence of the
[28:44] (1724.40s)
actions change. So in such cases because
[28:47] (1727.12s)
of the factorized representation in DMPC
[28:50] (1730.32s)
we can uh simply just adapt the dynamics
[28:53] (1733.28s)
model on some play data collected in the
[28:56] (1736.88s)
new environment and uh we observe that
[28:59] (1739.52s)
we can recover a lot of the performance
[29:02] (1742.88s)
lost because of the changed dynamics.
[29:04] (1744.80s)
Finally, we dug into the various
[29:07] (1747.36s)
components of uh the DMPC design and we
[29:11] (1751.12s)
demonstrated that uh the different
[29:13] (1753.04s)
components in DMPC basically contributed
[29:15] (1755.60s)
to improved performance.
[29:19] (1759.20s)
These include: the diffusion
[29:21] (1761.52s)
action proposals improve
[29:23] (1763.76s)
performance and simplify the planning.
[29:26] (1766.48s)
We do multi-step diffusion action
[29:28] (1768.72s)
proposals, and the fact that we do
[29:31] (1771.12s)
multi-step also uh contributes to
[29:33] (1773.60s)
improved performance and finally
[29:35] (1775.52s)
multi-step dynamics modeling also uh
[29:38] (1778.32s)
contributes to improved performance.
[29:41] (1781.60s)
Uh that's it.
[29:46] (1786.20s)
[applause]
[29:50] (1790.80s)
All right. And that was the last Google
[29:52] (1792.40s)
DeepMind paper that they're going to
[29:53] (1793.92s)
publish. So, good luck out there. Um,
[29:57] (1797.12s)
this next one is one of my lab mates
[29:59] (1799.68s)
that I work with a lot that is the most
[30:02] (1802.56s)
world-model-pilled person
[30:05] (1805.44s)
that I know. [laughter]
[30:08] (1808.24s)
And so, I can't imagine, you know,
[30:10] (1810.80s)
anyone else presenting this paper other
[30:12] (1812.80s)
than Yann LeCun himself. Um, [laughter]
[30:16] (1816.80s)
Isaac Ward. There you go. Thanks a lot.
[30:22] (1822.88s)
>> All right, guys. Is Is that a good
[30:24] (1824.00s)
distance? You all can hear me at the
[30:25] (1825.04s)
back. Cool. Cool. Yeah, I'm enjoying a
[30:28] (1828.24s)
uh a cool little period in life where I
[30:30] (1830.56s)
started working on world models a couple
[30:31] (1831.92s)
years ago, kind of before they got
[30:33] (1833.20s)
really hot and now they're enjoying a
[30:34] (1834.48s)
moment in the sun and suddenly everyone
[30:36] (1836.24s)
wants to talk to me which is nice. I'm
[30:38] (1838.16s)
presenting LeWorldModel, which
[30:39] (1839.76s)
of course comes out of Yann LeCun's
[30:41] (1841.52s)
group. Uh QR code here if you want to
[30:43] (1843.20s)
follow along with the project page, but
[30:44] (1844.32s)
I'll explain through it and yeah, really
[30:46] (1846.24s)
excited to talk to you about this one.
[30:47] (1847.44s)
Uh hidden in this presentation is really
[30:49] (1849.60s)
like a billion-dollar question and it's
[30:51] (1851.20s)
not hyperbole. Yann LeCun's raise of
[30:53] (1853.60s)
$1.03 billion back in March
[30:55] (1855.60s)
basically just to train world models is
[30:57] (1857.36s)
sort of what this presentation is about.
[30:58] (1858.64s)
I want to get at some of the questions
[31:00] (1860.24s)
that they're going to be testing. First
[31:02] (1862.08s)
five slides here just going to do some
[31:03] (1863.44s)
basics on world models. I think we've
[31:04] (1864.72s)
all heard the term but I want to just
[31:06] (1866.08s)
make sure we're all on the same page and
[31:07] (1867.36s)
then we'll jump into uh what this paper
[31:10] (1870.08s)
is really uh offering and what it means
[31:11] (1871.68s)
for world models at large. But first of
[31:14] (1874.24s)
all, world models, what are they? Why do
[31:16] (1876.00s)
we care about them? So really it's about
[31:17] (1877.44s)
learning the dynamics of the world,
[31:18] (1878.96s)
which is to say we're trying to come up
[31:20] (1880.64s)
with some model. Typically, we're using
[31:22] (1882.56s)
like a big neural network to predict how
[31:24] (1884.08s)
a system will change over time based on
[31:25] (1885.68s)
its inputs. So, you have your current
[31:27] (1887.60s)
state or scenario using S for notation
[31:29] (1889.92s)
here. You're playing some action, maybe
[31:31] (1891.52s)
that's like a movement or a command for
[31:33] (1893.28s)
a robot, um, or a language command for a
[31:35] (1895.36s)
robot, and then you're trying to predict
[31:36] (1896.56s)
like what its outcome is going to be,
[31:38] (1898.24s)
like what scenario will it end up in
[31:39] (1899.76s)
once it's executed that action. So,
[31:41] (1901.68s)
you're really trying to model the system
[31:42] (1902.96s)
or the environment that the robot is in,
[31:44] (1904.64s)
modeling the world. It's a world model.
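The interface described above, current state plus action in, predicted next state out, can be made concrete with a toy example. This is a sketch under stated assumptions: `world_model` is a hand-written point-mass update standing in for a learned neural network, and the state layout, the time step, and the function names are all hypothetical.

```python
def world_model(state, action):
    """Hypothetical one-step model: (state, action) -> predicted next state.
    Toy point mass: state is (position, velocity), action is an acceleration."""
    position, velocity = state
    dt = 0.1  # assumed time step
    return (position + velocity * dt, velocity + action * dt)

def imagine(state, actions):
    """Roll the model forward over an action sequence without touching the
    real environment, returning every imagined state including the start."""
    trajectory = [state]
    for action in actions:
        state = world_model(state, action)
        trajectory.append(state)
    return trajectory

# Accelerate for two steps, then coast for one.
trajectory = imagine((0.0, 0.0), [1.0, 1.0, 0.0])
```

Feeding each prediction back in as the next input is what makes errors compound over long action sequences, which is the problem multi-step models try to reduce.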
[31:46] (1906.88s)
Uh, these kinds of models are really
[31:48] (1908.08s)
cool. They enable a few really
[31:49] (1909.76s)
interesting capabilities. One of them is
[31:51] (1911.44s)
generating imagined outcomes. We've
[31:53] (1913.12s)
probably all seen like the sort of weird
[31:55] (1915.60s)
kind of hallucinatory imagination
[31:58] (1918.56s)
sequences coming out of world models
[31:59] (1919.92s)
over the last couple years. We'll talk
[32:01] (1921.12s)
more about those and why they're useful.
[32:02] (1922.72s)
Uh this allows us to get to model based
[32:04] (1924.64s)
control. I'm glad Stannis kind of
[32:06] (1926.32s)
explained that in the last talk for me,
[32:07] (1927.68s)
so I'll skip over it. Um and the last
[32:09] (1929.44s)
piece is really cool. Surprise
[32:10] (1930.56s)
quantification. Uh I'll get to that
[32:12] (1932.40s)
later. Um but a really powerful
[32:13] (1933.68s)
capability of world models. I wanted to
[32:15] (1935.68s)
communicate to you all that this is not
[32:17] (1937.20s)
a new idea at all. It's really just kind
[32:18] (1938.88s)
of new advertising or packaging on an
[32:21] (1941.52s)
old idea. So I started going back
[32:22] (1942.96s)
through Google Scholar and this is a
[32:24] (1944.48s)
paper that I think is older than the
[32:25] (1945.76s)
average age of this room. Um from
[32:27] (1947.76s)
NeurIPS 1990 and of course Richard S.
[32:29] (1949.92s)
Sutton who we know from reinforcement
[32:31] (1951.68s)
learning basically describes exactly a
[32:33] (1953.44s)
modern world model a black box that
[32:34] (1954.80s)
takes as input its situation and its
[32:36] (1956.88s)
action that it's going to execute and
[32:38] (1958.24s)
outputs a prediction of its immediate
[32:39] (1959.84s)
next situation. So really really old
[32:41] (1961.76s)
idea and uh that's the flyer from
[32:43] (1963.28s)
NeurIPS 1990.
[32:45] (1965.68s)
Great. Right. So, getting a little bit
[32:46] (1966.56s)
more explicit um and changing the
[32:47] (1967.76s)
notation from state to observation just
[32:49] (1969.52s)
because in real world systems, we
[32:50] (1970.96s)
typically don't have access to the exact
[32:52] (1972.48s)
true state. We typically have some
[32:53] (1973.68s)
observation from sensors. This is just
[32:55] (1975.52s)
an example that I pulled up from some
[32:56] (1976.72s)
world models that we're training on a
[32:58] (1978.24s)
quadrotor. So, as an example, the
[33:00] (1980.64s)
observation that the quadrotor gets
[33:02] (1982.00s)
might be its current kinematic state,
[33:03] (1983.60s)
position, velocity, this kind of thing,
[33:05] (1985.04s)
in addition to the images that it's
[33:06] (1986.40s)
taken from a forward-facing camera. The
[33:08] (1988.00s)
action might be a control input, in this
[33:09] (1989.84s)
case a yaw, and move back to the left.
[33:11] (1991.76s)
And then we want to make a prediction
[33:13] (1993.04s)
that says well if you do that action
[33:14] (1994.48s)
you're going to end up slightly back in
[33:16] (1996.32s)
the room and looking to the left. And we
[33:17] (1997.84s)
actually want to generate what the
[33:19] (1999.20s)
sensor would return in this
[33:21] (2001.36s)
case. So high-dimensional
[33:23] (2003.04s)
observations, images and also LiDAR and
[33:25] (2005.76s)
things like that are completely on the
[33:26] (2006.80s)
table in world models. Uh they're really
[33:28] (2008.88s)
challenging because action sequences can
[33:30] (2010.56s)
be quite long. Um and the really big
[33:32] (2012.24s)
thing is that the minimum in the
[33:33] (2013.44s)
optimization landscape for these kinds
[33:34] (2014.96s)
of models may not correspond to the
[33:36] (2016.72s)
desired behavior. And more on that
[33:38] (2018.32s)
later. Um, but hopefully you'll agree
[33:40] (2020.00s)
that if you have trained a system that's
[33:41] (2021.28s)
capable of doing this thing, it must
[33:43] (2023.12s)
have an internal model of the world. And
[33:44] (2024.96s)
imbuing agents with an internal model of
[33:46] (2026.64s)
the world, um, is potentially a very
[33:48] (2028.88s)
useful capability. And that really is
[33:50] (2030.56s)
the big question. Are we going to have
[33:52] (2032.24s)
model free or model based policies? Are
[33:54] (2034.80s)
our agents going to have an internal
[33:55] (2035.92s)
model of the world or are they not? And
[33:57] (2037.52s)
this is sort of being fought out right
[33:58] (2038.80s)
now both in the research community and
[34:00] (2040.56s)
in like the startup community. So on the
[34:02] (2042.56s)
left, model free. The idea is you're
[34:04] (2044.64s)
taking some observations, you're feeding
[34:06] (2046.40s)
this into some kind of big neural
[34:08] (2048.08s)
network potentially with a bunch of
[34:09] (2049.28s)
interesting learning tricks there, but
[34:10] (2050.88s)
you're getting some optimal action out.
[34:12] (2052.32s)
So, it's just mapping between
[34:13] (2053.44s)
observation and some optimal action. But
[34:15] (2055.60s)
at no point is there an explicit
[34:17] (2057.20s)
representation of what the future might
[34:18] (2058.64s)
look like if you execute that action.
[34:20] (2060.24s)
These kinds of models are pretty good.
[34:22] (2062.16s)
There is growing evidence to show that
[34:24] (2064.80s)
internal to these neural networks are
[34:26] (2066.64s)
highly obfuscated and challenging to
[34:28] (2068.32s)
interpret world models uh sort of in the
[34:30] (2070.56s)
in the weights. uh I'll talk about a
[34:32] (2072.88s)
paper very briefly that speaks to
[34:35] (2075.04s)
that and maybe someone can present on it
[34:36] (2076.40s)
in a future week. And then over on the
[34:38] (2078.32s)
um other side, model based approaches,
[34:39] (2079.84s)
right? So now we're saying we're going
[34:40] (2080.80s)
to train this world model up explicitly
[34:42] (2082.56s)
and actually use that in our policy to
[34:45] (2085.20s)
be able to explicitly predict the
[34:46] (2086.72s)
outcome of potential actions. So yeah,
[34:48] (2088.96s)
totally like two different species of
[34:50] (2090.72s)
policies. With the model-free stuff, one of
[34:52] (2092.40s)
the weaknesses is they show a little bit
[34:53] (2093.84s)
of brittleness out of distribution.
[34:56] (2096.08s)
Um, model based ones are great because
[34:57] (2097.68s)
you can kind of quantify modeling error
[34:59] (2099.52s)
and this is really important when you're
[35:00] (2100.72s)
deploying things in the real world. Uh,
[35:02] (2102.40s)
we'll talk a little bit about this. I
[35:03] (2103.52s)
have a little asterisk here, some
[35:04] (2104.64s)
biological precedent which we'll speak
[35:06] (2106.40s)
to more. Um, and you have to have this
[35:08] (2108.24s)
additional mechanism of course which is
[35:09] (2109.60s)
a downside where you actually need to
[35:11] (2111.28s)
propose action candidates to evaluate
[35:12] (2112.96s)
with the world model um, which Stannis
[35:15] (2115.20s)
spoke to in the previous talk. This is a
[35:17] (2117.28s)
great paper. But I just wanted to chuck
[35:18] (2118.32s)
this in there uh which talks about how
[35:20] (2120.00s)
even model-free policies do have
[35:22] (2122.40s)
world models in them and a really really
[35:24] (2124.32s)
cool paper that hopefully can be
[35:25] (2125.44s)
presented in a future week. Uh just to
[35:28] (2128.00s)
make it concrete before we jump into the
[35:29] (2129.44s)
paper I wanted to just bring a little
[35:31] (2131.20s)
toy here just to show you what this
[35:32] (2132.64s)
looks like. So of course I went to Push-T,
[35:34] (2134.56s)
like all good researchers do, and in Push-T
[35:36] (2136.08s)
we basically just have an image of a
[35:37] (2137.84s)
little blue ball agent and you're trying
[35:39] (2139.44s)
to push the blue T into the green
[35:40] (2140.88s)
slot. The state, or rather the
[35:43] (2143.04s)
observation, is comprised of that image
[35:44] (2144.64s)
plus the 2D position of the end effector,
[35:46] (2146.48s)
and the 2D action of where you're going
[35:47] (2147.84s)
to move the end effector. So you can make a
[35:49] (2149.68s)
little architecture that looks like
[35:50] (2150.64s)
this. I just whipped this up. Couple
[35:52] (2152.00s)
hundred thousand parameters and um oh
[35:55] (2155.28s)
let's play this. So if that's the actual
[35:57] (2157.76s)
roll out, this is what the model thinks
[35:59] (2159.84s)
the action sequence is going to do. So
[36:02] (2162.64s)
you can see it's a little bit wobbly
[36:03] (2163.84s)
because it's a tiny model, but we can
[36:05] (2165.12s)
certainly train up models of these kinds
[36:06] (2166.72s)
of toy environments and indeed more
[36:08] (2168.24s)
complex ones. So what are the challenges
[36:10] (2170.00s)
associated with training this kind of
[36:11] (2171.28s)
model? Well, one is you're trying to
[36:12] (2172.80s)
learn the representation of the world.
[36:14] (2174.80s)
So how you're going to compactly
[36:16] (2176.08s)
represent those high-dimensional
[36:17] (2177.76s)
images or LiDAR inputs or
[36:20] (2180.16s)
high-dimensional sensor inputs at the same
[36:22] (2182.08s)
time as you're trying to learn how
[36:23] (2183.44s)
actions change that representation. So
[36:25] (2185.68s)
you're co-learning representation and
[36:27] (2187.76s)
dynamics. And there are many solutions
[36:30] (2190.32s)
in the optimization landscape that will
[36:32] (2192.80s)
essentially just cause you to do
[36:34] (2194.24s)
nothing. So for example a local
[36:36] (2196.64s)
minimum in the optimization landscape is
[36:38] (2198.48s)
to say, well, every state is just the same.
[36:40] (2200.24s)
It's a trivial collapse, basically. And
[36:43] (2203.20s)
there are many techniques in the
[36:44] (2204.24s)
literature to say how can you avoid
[36:45] (2205.68s)
these. So there are solutions of a
[36:47] (2207.92s)
variety of different kinds that basically
[36:49] (2209.44s)
say there's a way to avoid the collapse
[36:51] (2211.12s)
associated with training world models,
[36:52] (2212.40s)
and that's really where LeWorldModel
[36:53] (2213.84s)
comes in. It says, well, instead of
[36:55] (2215.60s)
having to use some manner of trick or
[36:58] (2218.00s)
like special method or a bunch of like
[36:59] (2219.76s)
hyperparameter tuning schedule, we're
[37:01] (2221.52s)
instead going to really drastically
[37:03] (2223.04s)
simplify this and go for a more elegant
[37:04] (2224.56s)
method. So, if you know a little bit
[37:06] (2226.32s)
about world models, there's some popular
[37:08] (2228.08s)
ones in the top right here. This is a
[37:09] (2229.28s)
figure straight out of the paper. So,
[37:10] (2230.80s)
PLDM is planning with latent dynamics
[37:12] (2232.80s)
models; DINO-WM is DINO, distillation
[37:15] (2235.36s)
with no labels, world model; Dreamer is out
[37:17] (2237.20s)
of DeepMind; and then temporal
[37:18] (2238.72s)
difference MPC is the final one. So, in
[37:20] (2240.96s)
some way, shape or form, I'll explain
[37:22] (2242.48s)
this, they use some kind of trick or
[37:25] (2245.36s)
challenging-to-configure design to
[37:27] (2247.44s)
avoid
[37:29] (2249.52s)
this collapse, and LeWorldModel is
[37:31] (2251.12s)
coming in and saying basically we can do
[37:33] (2253.20s)
this with sort of one hyperparameter and
[37:34] (2254.80s)
one loss term, which I'll talk about.
[37:36] (2256.96s)
There's really no time to go through all
[37:38] (2258.24s)
the different tricks that different
[37:39] (2259.68s)
world model approaches use because it
[37:41] (2261.76s)
really is the wild west out there right
[37:43] (2263.20s)
now so many different methods but they
[37:45] (2265.12s)
basically fall into one of these three
[37:46] (2266.40s)
categories so one is you could do some
[37:48] (2268.72s)
explicit heuristic that stops collapse by
[37:50] (2270.56s)
like enforcing some special um
[37:53] (2273.12s)
healthiness in like the latent space of
[37:54] (2274.72s)
your embeddings. Um the language trick
[37:57] (2277.36s)
is maybe a bit unfair here, but it's
[37:58] (2278.64s)
what's used in the paper. Uh you could
[38:00] (2280.32s)
use some foundational methods. So you
[38:01] (2281.76s)
could take some like existing
[38:02] (2282.96s)
autoencoder or diffusion model or video
[38:04] (2284.88s)
model and use that as a basis for your
[38:07] (2287.12s)
world model and add an action
[38:08] (2288.64s)
conditioning element in there. Um or you
[38:10] (2290.80s)
could use some privileged data that may
[38:12] (2292.40s)
not be usually available to the model
[38:14] (2294.16s)
outside of train time uh to be able to
[38:16] (2296.16s)
avoid collapse. And LeWorldModel, even
[38:18] (2298.72s)
though it says that it's doing something
[38:19] (2299.84s)
very different I really think uh it's
[38:21] (2301.68s)
just offering a new kind of trick uh
[38:23] (2303.52s)
which I'll talk about here so jer is
[38:25] (2305.44s)
joint embedding predictive architecture
[38:26] (2306.88s)
it's sort of yan lakun's main work and
[38:28] (2308.48s)
lay world model is a kind of jepper
[38:30] (2310.16s)
model uh basically the way it works is
[38:32] (2312.08s)
you're going to take an autoenccoder um
[38:34] (2314.40s)
or I should say an image encoder uh
[38:36] (2316.48s)
encode this observation in this case
[38:37] (2317.92s)
it's of a robot doing a push cube task
[38:40] (2320.40s)
that's going to turn that image into a
[38:41] (2321.84s)
latent vector in the latent space of
[38:44] (2324.08s)
this encoder uh you're going to train an
[38:46] (2326.00s)
action-conditioned forecasting module, this
[38:47] (2327.68s)
predictor to be able to predict what is
[38:49] (2329.36s)
the next latent embedding going to look
[38:51] (2331.28s)
like when I execute this action. So not
[38:53] (2333.12s)
what the next image is going to look
[38:54] (2334.16s)
like but what's the next latent going to
[38:55] (2335.44s)
look like and you can use the decoder
[38:57] (2337.68s)
attached to that encoder to decode that
[38:59] (2339.20s)
back out into a useful image. But for
[39:01] (2341.36s)
the most part all the interesting work
[39:02] (2342.56s)
is going to be done in the latent space.
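The latent prediction objective described above can be illustrated with a toy example. This is a sketch under stated assumptions, not the paper's architecture: `encode`, `collapsed_encode`, and `predict_next_latent` are hypothetical hand-written functions standing in for learned networks, and images are tiny nested lists.

```python
def encode(image):
    """Hypothetical encoder: one latent coordinate per image row (the row mean)."""
    return [sum(row) / len(row) for row in image]

def collapsed_encode(image):
    """A degenerate encoder that maps every image to the same latent."""
    return [0.0 for _ in image]

def predict_next_latent(latent, action):
    """Hypothetical action-conditioned predictor: the action shifts each coordinate."""
    return [z + action for z in latent]

def latent_prediction_loss(encoder, image_t, action, image_next):
    """Mean squared error between the predicted and the actual next latent.
    The comparison happens in latent space, never in pixel space."""
    predicted = predict_next_latent(encoder(image_t), action)
    target = encoder(image_next)
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)

image_t = [[0.0, 0.2], [0.4, 0.6]]
image_next = [[0.1, 0.3], [0.5, 0.7]]  # every pixel brightened by 0.1

good = latent_prediction_loss(encode, image_t, 0.1, image_next)   # right action
bad = latent_prediction_loss(encode, image_t, 0.5, image_next)    # wrong action
collapsed = latent_prediction_loss(collapsed_encode, image_t, 0.0, image_next)
```

The collapsed encoder also reaches zero loss while carrying no information about the image, which is the trivial collapse mentioned earlier in the talk and the reason this loss alone is not enough.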
[39:04] (2344.16s)
And basically what they say is over a
[39:05] (2345.60s)
batch all of those latent embeddings uh
[39:08] (2348.64s)
should be in a healthy distribution
[39:10] (2350.48s)
which they describe as a Gaussian
[39:12] (2352.56s)
distribution in the
[39:14] (2354.72s)
latent space, and thus enters the SIGReg
[39:17] (2357.52s)
regularizer, which is the sort of new
[39:18] (2358.88s)
term they add. So SIGReg: S for sketching, as
[39:21] (2361.60s)
in doing one-dimensional passes over
[39:23] (2363.76s)
high-dimensional data; I for
[39:25] (2365.92s)
isotropic, so this should look the same
[39:27] (2367.60s)
when you slice it in any direction; and G
[39:29] (2369.20s)
for Gaussian distributed. SIGReg. So
[39:31] (2371.12s)
basically you're taking all of these
[39:32] (2372.08s)
embeddings of your different predictions
[39:34] (2374.64s)
doing a one-dimensional slice over each
[39:36] (2376.88s)
direction in that high-dimensional
[39:38] (2378.80s)
space, and then you want each of the
[39:40] (2380.24s)
curves across those slices to be Gaussian
[39:42] (2382.56s)
distributed and if that's true then your
[39:45] (2385.44s)
um distribution in the latent space must
[39:47] (2387.36s)
be very healthy. Uh so the idea is you
[39:49] (2389.44s)
can quite cheaply evaluate how Gaussian
[39:51] (2391.36s)
distributed your embeddings are and thus
[39:52] (2392.96s)
how healthy your world model is and how
[39:54] (2394.72s)
non-collapsing it is. So essentially they
[39:57] (2397.04s)
just say, instead of training only on the
[39:58] (2398.24s)
normal predict-the-next-latent loss, you
[40:00] (2400.56s)
add on this additional SIGReg term. So I'd
[40:03] (2403.04s)
argue that basically this paper is just
[40:04] (2404.72s)
um providing a very elegant kind of
[40:06] (2406.32s)
regularization. And to finish off I'll
[40:08] (2408.00s)
just talk about three capabilities that
[40:09] (2409.20s)
you get from this. So one is the
[40:10] (2410.96s)
open-loop prediction quality. This is
[40:13] (2413.20s)
what world models do. So you feed in
[40:14] (2414.56s)
the context, this Push-T at the top,
[40:16] (2416.64s)
and you can see the top row is the real
[40:18] (2418.32s)
example. The bottom is the imagined and
[40:20] (2420.08s)
they look about the same. This is good.
[40:21] (2421.44s)
It means your world model is really good
[40:22] (2422.72s)
at predicting what your next action is
[40:24] (2424.32s)
going to do. They do that on Push-T and
[40:25] (2425.84s)
then on a 3D analog
[40:28] (2428.08s)
task, a push cube. This is all
[40:30] (2430.64s)
great. I love seeing these um these
[40:32] (2432.24s)
plots. Um but really what matters is how
[40:34] (2434.96s)
does this actually affect the policy
[40:36] (2436.48s)
like for the actual task completion. How
[40:38] (2438.40s)
is this useful? Um and that sort of
[40:40] (2440.64s)
brings us into how you can use these
[40:41] (2441.76s)
models for model predictive control.
[40:43] (2443.92s)
Basically you take your initial
[40:45] (2445.12s)
observation and a goal observation. I
[40:47] (2447.60s)
put an asterisk there because how often
[40:48] (2448.96s)
do you have a goal observation in a
[40:50] (2450.24s)
robotics task? Like you don't always
[40:51] (2451.60s)
know exactly the situation that you want
[40:53] (2453.60s)
to end up in. But in this case, that's
[40:54] (2454.88s)
how they frame it. So they say, you
[40:56] (2456.08s)
know, the world looks like this right
[40:57] (2457.12s)
now. I want the world to look like this.
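The goal-conditioned setup can be sketched as a search in latent space: encode the current and goal observations, then look for an action sequence whose predicted final latent lands near the goal latent. This is a minimal illustration that uses plain random shooting as the search, which is a swapped-in choice and not necessarily the optimizer used in the paper. `encode` and `predict_next_latent` are hypothetical stand-ins for the learned encoder and predictor.

```python
import random

def encode(observation):
    """Hypothetical encoder: here the latent is just the observation itself."""
    return [float(x) for x in observation]

def predict_next_latent(latent, action):
    """Hypothetical predictor: the action is added to the latent."""
    return [z + a for z, a in zip(latent, action)]

def latent_distance(z_a, z_b):
    """Squared Euclidean distance between two latents."""
    return sum((a - b) ** 2 for a, b in zip(z_a, z_b))

def plan_to_goal(obs_start, obs_goal, horizon=4, num_samples=256, seed=0):
    """Random shooting: sample action sequences, roll them out in latent
    space, and keep the one that ends closest to the goal latent."""
    rng = random.Random(seed)
    z_goal = encode(obs_goal)
    best_cost, best_plan = float("inf"), None
    for _ in range(num_samples):
        candidate = [[rng.uniform(-1.0, 1.0) for _ in z_goal]
                     for _ in range(horizon)]
        z = encode(obs_start)
        for action in candidate:
            z = predict_next_latent(z, action)
        cost = latent_distance(z, z_goal)
        if cost < best_cost:
            best_cost, best_plan = cost, candidate
    return best_plan, best_cost

start, goal = (0.0, 0.0), (2.0, -1.0)
best_plan, best_cost = plan_to_goal(start, goal)
initial_cost = latent_distance(encode(start), encode(goal))
```

In a model predictive control loop, only the first action of `best_plan` would typically be executed before planning again from the new observation.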
[40:59] (2459.04s)
You encode both of those. And then
[41:00] (2460.56s)
you're basically doing a search over the
[41:02] (2462.32s)
actions that will get you in the latent
[41:04] (2464.32s)
space from this starting point to this
[41:06] (2466.48s)
ending point. And there are
[41:07] (2467.60s)
well-defined optimization methods to
[41:09] (2469.60s)
achieve that. It works pretty well. I'll
[41:11] (2471.92s)
make it um make it simple. The world
[41:14] (2474.16s)
model is better than the competition on
[41:15] (2475.92s)
these like small 2D tasks. As soon as
[41:17] (2477.92s)
you go to 3D, DINO World Model wins. It
[41:20] (2480.00s)
does have a big foundational backbone
[41:21] (2481.44s)
trained on that kind of image data. So
[41:23] (2483.12s)
you'd expect it to um to win. Um they
[41:26] (2486.16s)
run on a really simple environment
[41:27] (2487.60s)
called Two-Room and kind of say, you know,
[41:29] (2489.84s)
we don't do so well on this but that's
[41:31] (2491.44s)
because we're promoting like really high
[41:33] (2493.20s)
dimensional healthy embeddings and it's
[41:34] (2494.56s)
a very low dimensional problem. I'm not
[41:36] (2496.80s)
sure if I'd truly go for that. Um but a
[41:38] (2498.48s)
good takeaway is that it's about 50 times
[41:40] (2500.00s)
faster than any of the competition
[41:41] (2501.52s)
across the board because it's doing all
[41:42] (2502.88s)
this work in the latent space and it
[41:44] (2504.56s)
doesn't have to have any like additional
[41:46] (2506.00s)
tricks relating to more forward passes
[41:48] (2508.00s)
or like having two copies of the model
[41:49] (2509.76s)
in memory. And uh you can actually boot
[41:51] (2511.52s)
this thing up on like a single card,
[41:52] (2512.88s)
less than 24 gigabytes of VRAM and it's
[41:54] (2514.88s)
only 15 million parameters. So that is
[41:56] (2516.40s)
pretty nice. Final piece, this is what I
[41:59] (2519.44s)
think is a really cool capability of
[42:01] (2521.04s)
world models. Um you can quantify the
[42:02] (2522.80s)
model error. So basically they just come
[42:04] (2524.64s)
up with some trajectories that kind of
[42:06] (2526.32s)
screw with the world model. So the top
[42:07] (2527.60s)
one is going from left to right. That's
[42:09] (2529.44s)
time. Uh so that's just like a nominal
[42:11] (2531.20s)
example. Everything's normal. Then they
[42:12] (2532.88s)
take the same example, but they change
[42:14] (2534.72s)
the color of the T. And then they take
[42:16] (2536.08s)
the same example, but they just teleport
[42:17] (2537.52s)
the T into a different location. And
[42:19] (2539.60s)
this is really cool because you can
[42:20] (2540.88s)
actually see the moment they apply those
[42:22] (2542.72s)
perturbations, you get a spike in the
[42:24] (2544.32s)
model error and this is detectable which
[42:26] (2546.00s)
is to say world model enabled agents can
[42:28] (2548.24s)
quantify how poor their predictions are.
[42:29] (2549.84s)
They have good estimates of their
[42:31] (2551.36s)
uncertainty. This is really powerful.
[42:33] (2553.12s)
Model-free approaches don't
[42:34] (2554.40s)
natively give you this stuff.
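To make the spike-in-model-error idea concrete, here is a minimal sketch on made-up numbers. The function names and the median-plus-MAD threshold are illustrative assumptions, not the paper's procedure; the only idea taken from the talk is comparing predicted latents against observed ones over time.

```python
import numpy as np

def surprise(pred_latents, next_latents):
    """Per-step squared error between predicted and actually observed latents."""
    pred = np.asarray(pred_latents, dtype=float)
    obs = np.asarray(next_latents, dtype=float)
    return np.sum((pred - obs) ** 2, axis=-1)

def flag_spikes(errors, n_sigma=3.0):
    """Flag steps whose error exceeds median + n_sigma * robust std (from the MAD)."""
    errors = np.asarray(errors, dtype=float)
    med = np.median(errors)
    mad = np.median(np.abs(errors - med))
    return errors > med + n_sigma * 1.4826 * mad

# Toy trajectory: predictions are nearly right, except at step 6,
# where the object "teleports" and the observed latent jumps.
pred = np.zeros((10, 2))
obs = np.zeros((10, 2))
obs[:, 0] = 0.01 * (np.arange(10) % 3)
obs[6] = [1.0, 1.0]

errors = surprise(pred, obs)
flags = flag_spikes(errors)
```

Here only the perturbed step is flagged; in practice the threshold would need calibrating on nominal rollouts.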
[42:37] (2557.04s)
This is my last slide. Um a few
[42:38] (2558.80s)
discussion points and broader themes
[42:40] (2560.08s)
maybe we can chat about here. Obviously,
[42:41] (2561.84s)
you know, are we going to go with
[42:42] (2562.96s)
model-based? Are we going to go with
[42:44] (2564.16s)
model-free? Um what's going to be the best way
[42:45] (2565.68s)
to enable intelligent agents to do
[42:47] (2567.44s)
interesting things in the world?
[42:48] (2568.96s)
Regularization and representation
[42:50] (2570.64s)
learning. Um, in this paper they are
[42:52] (2572.64s)
co-learning the representation of the
[42:55] (2575.04s)
world that the agent has and the
[42:56] (2576.24s)
dynamics of the world. Should this be
[42:58] (2578.08s)
separated? Can we take some bio
[42:59] (2579.52s)
inspiration? Should we use pre-existing
[43:01] (2581.92s)
um like foundation models and stuff like
[43:03] (2583.60s)
that? And then finally, how can we fight
[43:05] (2585.52s)
uh representational collapse elegantly?
[43:07] (2587.20s)
I think this work does a really great
[43:08] (2588.56s)
job of that, but the question is still
[43:10] (2590.24s)
out on what the best way to do it is. So
[43:12] (2592.48s)
um that's my talk. Thanks very much for
[43:14] (2594.08s)
your attention. [applause]
[43:21] (2601.12s)
All right.
[43:24] (2604.88s)
So, for the next two,
[43:27] (2607.68s)
um, we're kind of focusing on, um, less
[43:31] (2611.84s)
world model stuff and more heady, high
[43:35] (2615.12s)
level stuff that I think is pretty
[43:36] (2616.72s)
interesting. Um, this is a a paper
[43:40] (2620.32s)
that's going to be presented by Ashe,
[43:41] (2621.76s)
one of the YC uh, startups here named
[43:44] (2624.80s)
Q Labs. And you're co-founder, president?
[43:47] (2627.92s)
You're president of Q Labs. Is that right?
[43:49] (2629.84s)
>> Okay. Welcome Ashe.
[43:53] (2633.78s)
[applause]
[43:54] (2634.48s)
>> Hey everybody. Today I'm going to be
[43:56] (2636.08s)
talking through Andrew Gordon Wilson's
[43:58] (2638.16s)
paper, uh, "Deep Learning Is Not So
[44:00] (2640.24s)
Mysterious or Different." Uh we actually
[44:02] (2642.40s)
work with Andrew on the generalization
[44:04] (2644.08s)
problem at Q Labs. So I'm really excited
[44:06] (2646.16s)
for more people to know about his work.
[44:07] (2647.84s)
The current state of machine learning is
[44:09] (2649.60s)
that we know that scaling
[44:11] (2651.60s)
models leads to better generalization.
[44:14] (2654.00s)
But we don't have a mechanistic
[44:15] (2655.44s)
understanding of why that is the case.
[44:18] (2658.00s)
Um yeah, if we can understand
[44:20] (2660.80s)
generalization, then we might be able to
[44:23] (2663.28s)
optimize for it as well. So the payoff
[44:24] (2664.96s)
to understanding it is actually really
[44:27] (2667.04s)
really large. Um when you talk to people
[44:29] (2669.36s)
in the field, they often explain that
[44:31] (2671.52s)
generalization is a mystery and they
[44:33] (2673.36s)
point to examples like
[44:34] (2674.56s)
overparameterization,
[44:36] (2676.08s)
benign overfitting, and double
[44:38] (2678.32s)
descent as reasons why we might not be
[44:40] (2680.80s)
able to understand generalization at
[44:42] (2682.80s)
all. So Andrew's work here basically
[44:45] (2685.52s)
dispels those mysteries by using
[44:47] (2687.68s)
classical theories of generalization uh
[44:49] (2689.84s)
which have to date not really been
[44:51] (2691.92s)
used to explain things like
[44:53] (2693.52s)
overparameterization thus far. So the
[44:56] (2696.72s)
first classical theory that we'll go
[44:58] (2698.08s)
through is uh PAC-Bayes. So PAC-Bayes
[45:00] (2700.88s)
basically bounds the test loss which is
[45:02] (2702.88s)
the generalization. This is the quantity
[45:04] (2704.72s)
that we care about with a training loss
[45:06] (2706.64s)
and a compression term. Um the thing is
[45:09] (2709.52s)
in the past when people overparameterize
[45:12] (2712.16s)
models this compression term tends to
[45:14] (2714.16s)
dominate and so in practice these bounds
[45:16] (2716.64s)
become loose and vacuous meaning that we
[45:18] (2718.96s)
can't use them for anything at all. This
[45:21] (2721.36s)
was basically due to a misapplication of
[45:23] (2723.36s)
the bound. You can compute the
[45:25] (2725.36s)
compression term in an alternative way
[45:27] (2727.20s)
as we'll get into sort of later in the
[45:29] (2729.28s)
talk here. So let's go through the first
[45:31] (2731.92s)
mystery that uh Andrew goes through in
[45:34] (2734.00s)
his paper. Um the the mystery that he
[45:36] (2736.80s)
talks about is overparameterization. And
[45:39] (2739.28s)
this is basically the idea that as you
[45:41] (2741.20s)
scale up the the model parameter size
[45:43] (2743.20s)
from the bias-variance
[45:44] (2744.96s)
trade-off, you would expect that you
[45:47] (2747.52s)
might overfit. But in practice, we see
[45:49] (2749.44s)
the opposite. The scaling laws tell us
[45:51] (2751.12s)
that we actually get better
[45:52] (2752.48s)
generalization. Um the scaling
[45:56] (2756.16s)
and the better generalization from
[45:57] (2757.44s)
overparameterization are behind
[45:59] (2759.76s)
the massive gains in model
[46:01] (2761.76s)
capability over the last couple of
[46:03] (2763.28s)
years. But we still don't really
[46:04] (2764.88s)
understand why it improves
[46:07] (2767.28s)
generalization.
[46:09] (2769.52s)
So the PAC-Bayes framework gives us a
[46:11] (2771.68s)
pretty useful way to think about the
[46:13] (2773.60s)
success of overparameterization.
[46:15] (2775.92s)
The first is with empirical risk.
[46:17] (2777.44s)
Empirical risk is basically training
[46:18] (2778.80s)
loss. When you increase the number of
[46:20] (2780.32s)
parameters you can fit your data better.
[46:22] (2782.24s)
Um so the empirical risk the left uh the
[46:24] (2784.96s)
first term goes down.
[46:27] (2787.52s)
And Andrew's work also finds that when
[46:30] (2790.64s)
we increase the model, when we increase
[46:32] (2792.80s)
the number of parameters, um we also
[46:35] (2795.20s)
find more compressible solutions. So
[46:37] (2797.20s)
this is work by Lotfi et al., and
[46:39] (2799.52s)
they develop methods to basically
[46:41] (2801.44s)
compress the uh yeah, they compress the
[46:44] (2804.32s)
training set and the
[46:46] (2806.80s)
model and they basically find a negative
[46:48] (2808.64s)
correlation between the bits
[46:51] (2811.20s)
required to encode the training set and
[46:52] (2812.80s)
the number of parameters. Um and so we
[46:55] (2815.68s)
find that as we increase the model size
[46:57] (2817.68s)
we can find more efficient encodings of
[47:00] (2820.00s)
the training set. So the second term
[47:03] (2823.20s)
in this bound also gets lower.
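For reference, here is one standard form of a compression-style generalization bound with the two terms the speaker describes. This is a sketch of the general shape, not necessarily the exact statement in the paper; the symbols $K(h)$, $\Delta$, $\delta$, and $n$ are notation chosen here.

```latex
% With probability at least 1 - \delta over n i.i.d. training samples,
% simultaneously for every hypothesis h in a countable class:
R(h) \;\le\; \hat{R}(h) \;+\; \Delta \sqrt{\frac{K(h)\,\log 2 + \log(1/\delta)}{2n}}
% R(h)       : expected (test) risk
% \hat{R}(h) : empirical (training) risk
% \Delta     : range of the bounded loss
% K(h)       : bits needed to encode h under a fixed prefix-free code
```

The first term falls as a larger model fits the data better; the second falls if the solution it finds can be described in fewer bits.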
[47:07] (2827.20s)
Another perspective on this model
[47:08] (2828.80s)
compressibility point is a perspective
[47:10] (2830.48s)
of flatness. As you increase the number
[47:12] (2832.64s)
of parameters, it turns out that the
[47:14] (2834.64s)
volume of flat minima in
[47:17] (2837.52s)
parameter space exponentially increases.
[47:19] (2839.92s)
This is the green region and uh and
[47:22] (2842.40s)
comparatively the volume of sharp
[47:24] (2844.40s)
minima increases much less and uh this
[47:27] (2847.36s)
is interesting and this is useful for the
[47:29] (2849.20s)
compressibility view because flat minima
[47:31] (2851.28s)
are known to be more compressible than
[47:33] (2853.12s)
sharp minima and so overparameterization
[47:36] (2856.16s)
fits within existing theories and
[47:38] (2858.40s)
through Andrew's work we actually see
[47:40] (2860.32s)
useful bounds on generalization even for
[47:42] (2862.96s)
models at like a billion parameter
[47:44] (2864.88s)
scale and so we go to the next so-called
[47:47] (2867.28s)
mystery of deep learning which is called
[47:49] (2869.44s)
uh benign overfitting which Andrew also
[47:51] (2871.36s)
dispels in or at least partially
[47:53] (2873.44s)
explains in his paper. So the idea of
[47:56] (2876.00s)
benign overfitting is that deep neural
[47:58] (2878.08s)
networks are able to fit totally random
[48:00] (2880.16s)
noise but at the same time they are able
[48:02] (2882.24s)
to generalize well when you have
[48:04] (2884.48s)
structured data. The mystery is how can
[48:07] (2887.12s)
you have an inductive bias that allows
[48:09] (2889.20s)
you to generalize well if you can also
[48:11] (2891.20s)
fit totally random data. I think a
[48:13] (2893.44s)
regularized polynomial model um in
[48:15] (2895.44s)
Andrew's paper gives us pretty good
[48:17] (2897.12s)
intuition for how this might be the
[48:18] (2898.64s)
case. Here you can see that on random
[48:21] (2901.28s)
data, so section C of the figure, we
[48:23] (2903.84s)
have enough parameters to fit the data
[48:25] (2905.60s)
and so we can fit the totally
[48:28] (2908.24s)
random data. But on structured data, the
[48:30] (2910.56s)
the regularization pushes us to use the
[48:32] (2912.48s)
lower order terms. And so we are able to
[48:34] (2914.72s)
both get the flexibility but also have
[48:37] (2917.36s)
inductive bias that allows us to
[48:38] (2918.88s)
generalize. And generally this is this
[48:41] (2921.68s)
is the view to take um for for neural
[48:44] (2924.96s)
networks: they are expressive
[48:47] (2927.36s)
models with a soft inductive bias. Um we
[48:50] (2930.08s)
can go through this concept um just
[48:51] (2931.92s)
using this figure right here. So uh on
[48:54] (2934.40s)
the left hand side we have an example of
[48:56] (2936.56s)
of what's like a flexible hypothesis
[48:58] (2938.72s)
space. And a flexible hypothesis space
[49:00] (2940.80s)
would allow you to fit the data that you
[49:02] (2942.48s)
have. But the problem is that you would
[49:04] (2944.40s)
almost certainly overfit if you
[49:07] (2947.20s)
do not have a bias towards one
[49:10] (2950.00s)
solution over the other. But on the
[49:11] (2951.84s)
other hand, if you have an inductive
[49:13] (2953.28s)
bias, you would solve this overfitting
[49:15] (2955.04s)
problem, but instead you
[49:17] (2957.52s)
wouldn't be able to model all of the
[49:18] (2958.96s)
details of reality. Um and so the middle
[49:21] (2961.52s)
ground is to have a very expressive
[49:23] (2963.76s)
hypothesis space, but also have a bias
[49:26] (2966.16s)
towards solutions that might generalize.
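The regularized polynomial mentioned a moment ago is a compact way to see this middle ground. Below is a minimal sketch, assuming a penalty of the form `lam * gamma**j` on the coefficient of `x**j`; the specific values of `degree`, `lam`, and `gamma` are illustrative choices, not taken from the paper.

```python
import numpy as np

def fit_soft_bias_poly(x, y, degree=10, lam=1e-3, gamma=10.0):
    """Ridge fit of a polynomial where the penalty on the coefficient of x**j
    grows as lam * gamma**j, so higher-order terms cost more to use."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    Phi = np.vander(x, degree + 1, increasing=True)  # columns: x**0, x**1, ...
    penalty = np.diag(lam * gamma ** np.arange(degree + 1))
    return np.linalg.solve(Phi.T @ Phi + penalty, Phi.T @ y)

# Structured (linear) data: the model has 11 coefficients available,
# but the order-dependent penalty keeps the fit essentially linear.
x = np.linspace(-1.0, 1.0, 30)
y = 1.0 + 2.0 * x
w = fit_soft_bias_poly(x, y)
```

The hypothesis space still contains every degree-10 polynomial; the penalty only expresses a preference, which is what makes the bias soft rather than a hard constraint.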
[49:28] (2968.32s)
For example, in the PAC-Bayes framework,
[49:30] (2970.40s)
we might want to bias towards more
[49:32] (2972.56s)
compressible models if we can. And so we
[49:34] (2974.72s)
see that uh deep learning so-called
[49:36] (2976.64s)
mysteries are actually consistent and
[49:38] (2978.72s)
partially explained by existing theories
[49:40] (2980.88s)
such as soft inductive biases and PAC-Bayes.
[49:45] (2985.04s)
And sort of the thing I want to leave
[49:46] (2986.16s)
you with is that um if we can find
[49:48] (2988.88s)
the right inductive biases building on
[49:51] (2991.20s)
these theories, we might be able to
[49:52] (2992.96s)
optimize for them as well. And by the no
[49:55] (2995.44s)
free lunch theorem, the only way that we
[49:57] (2997.20s)
get improvements in learning efficiency
[49:58] (2998.88s)
is through inductive biases. So I
[50:01] (3001.12s)
think that working on this
[50:02] (3002.72s)
problem is a really good bet to make.
[50:04] (3004.88s)
Given the massive sample efficiency gap
[50:06] (3006.96s)
between AI and humans, we might actually
[50:08] (3008.80s)
see massive gains in capability if we
[50:11] (3011.20s)
work on this problem. Um and so yeah,
[50:13] (3013.52s)
that's what I want to leave you with.
[50:14] (3014.80s)
Short presentation.
[50:16] (3016.96s)
[applause]
[50:19] (3019.76s)
Okay. Um so for this last paper then
[50:23] (3023.04s)
after this we have some boba for
[50:24] (3024.80s)
everyone. So sit tight 15 minutes. Um
[50:30] (3030.24s)
this is an idea that you know I've been
[50:32] (3032.40s)
obsessed with. Back to the sample
[50:33] (3033.84s)
efficiency thing. I think that like the
[50:35] (3035.20s)
two major problems we have left really
[50:36] (3036.40s)
to solve in AI are intelligence per
[50:38] (3038.56s)
watt um and intelligence per sample. And
[50:41] (3041.04s)
if you compare that to to where we're at
[50:42] (3042.80s)
today compared to humans, um I would say
[50:45] (3045.28s)
that we're still an order or two of
[50:48] (3048.24s)
magnitude off on intelligence per watt.
[50:50] (3050.48s)
Uh and we're like orders of magnitude
[50:53] (3053.20s)
off on intelligence per sample. I don't
[50:55] (3055.04s)
know what percent of the internet that
[50:56] (3056.48s)
you guys have read, but I have not read
[50:58] (3058.24s)
the entire internet. In Chris Ré's lab
[50:59] (3059.92s)
in particular, we've been obsessed with
[51:01] (3061.12s)
this idea that um if I have
[51:05] (3065.84s)
a fixed-size amount of data and I
[51:08] (3068.80s)
have infinite compute, just go nuts, how
[51:10] (3070.72s)
much generalization can I actually
[51:12] (3072.08s)
achieve? And so this is exactly uh the
[51:15] (3075.04s)
paper that starts to answer that
[51:16] (3076.48s)
question. And I'm really excited to uh
[51:18] (3078.48s)
introduce uh Konwoo.
[51:24] (3084.80s)
>> Uh hi, I'm Konwoo. Um this is a paper that I
[51:28] (3088.64s)
co-led with my amazing collaborator
[51:30] (3090.80s)
Suhas as well as Percy and Tatsu.
[51:35] (3095.12s)
So part of the motivation for this paper
[51:37] (3097.28s)
is just the fact that over the past uh
[51:40] (3100.32s)
six or seven years pre-training has
[51:42] (3102.56s)
continued to improve model capabilities
[51:44] (3104.32s)
in pretty surprising ways. So in 2020
[51:47] (3107.92s)
with GPT-3 we had sort of the emergence
[51:50] (3110.88s)
of in-context learning. In 2022 with
[51:54] (3114.16s)
Anthropic's RLHF, we had sort of the
[51:57] (3117.04s)
advent of alignment. And maybe most
[51:59] (3119.44s)
notably in 2024 with both o1 from OpenAI
[52:03] (3123.52s)
and then later DeepSeek-R1, we had the
[52:05] (3125.68s)
emergence of reasoning. And in fact,
[52:08] (3128.00s)
even still today, we see that with these
[52:10] (3130.08s)
newer and bigger pre-training runs like
[52:12] (3132.72s)
Mythos and 5.5, the models just continue
[52:16] (3136.16s)
to get better. And so because
[52:17] (3137.44s)
pre-training is very expensive, a lot of
[52:19] (3139.84s)
the focus on the research side of things
[52:22] (3142.16s)
has been on how do we improve compute
[52:23] (3143.84s)
efficiency. And in general, people have
[52:26] (3146.24s)
found that to improve compute
[52:27] (3147.44s)
efficiency, you need to scale both the
[52:30] (3150.16s)
number of parameters in your model and
[52:32] (3152.08s)
the number of data points that you train
[52:33] (3153.36s)
your model on. And so these were
[52:35] (3155.52s)
quantified with the so-called Chinchilla
[52:37] (3157.44s)
scaling laws. The problem with compute
[52:39] (3159.04s)
efficiency is that we're soon going to
[52:40] (3160.80s)
be constrained by data. And so if you
[52:43] (3163.12s)
look at these sort of public projections
[52:44] (3164.88s)
of the rate of growth of internet data,
[52:47] (3167.44s)
they suggest that the amount of sort of
[52:49] (3169.52s)
human generated text on the internet
[52:51] (3171.52s)
grows by roughly 3% per year. And the
[52:54] (3174.48s)
amount of compute that we're spending on
[52:55] (3175.92s)
pre-training is growing by roughly four
[52:58] (3178.64s)
or 5x per year. And so what this
[53:01] (3181.76s)
suggests is that as time passes on, the
[53:05] (3185.44s)
amount of compute that we're willing to
[53:07] (3187.20s)
spend per data point is going to
[53:09] (3189.20s)
continue to increase by roughly 4x
[53:11] (3191.12s)
year-over-year. And so this sort of
[53:12] (3192.64s)
motivates the core question in this
[53:14] (3194.48s)
paper which is how should you approach
[53:16] (3196.72s)
pre-training when you're constrained by
[53:18] (3198.72s)
data but totally unconstrained by
[53:20] (3200.72s)
compute. And it's worth maybe spending a
[53:23] (3203.84s)
few seconds to think for yourself if you
[53:25] (3205.92s)
haven't already seen this paper like
[53:27] (3207.52s)
what would you do in this situation.
[53:29] (3209.12s)
This is a very different algorithmic
[53:30] (3210.80s)
regime from sort of the
[53:32] (3212.48s)
compute-efficient pre-training world that we've
[53:34] (3214.48s)
sort of lived in for sort of most of
[53:37] (3217.60s)
modern time. And it's also worth
[53:39] (3219.76s)
noting that this question is not that
[53:41] (3221.84s)
different from how machine learning
[53:43] (3223.60s)
worked before the modern LLM. So for
[53:47] (3227.04s)
things like classical statistics where
[53:48] (3228.80s)
maybe you really care about your rates
[53:50] (3230.40s)
with respect to the number of points of
[53:51] (3231.68s)
data you have and you don't care about
[53:53] (3233.28s)
compute or even older benchmarks like
[53:55] (3235.84s)
MNIST and Penn Treebank where you're
[53:57] (3237.76s)
sort of implicitly data constrained
[53:59] (3239.44s)
because the benchmarks don't have that
[54:01] (3241.04s)
many data points.
[54:03] (3243.76s)
And so sort of the core contribution
[54:05] (3245.52s)
that I'll explain in this paper is that
[54:08] (3248.32s)
we bring the modern toolkit of scaling
[54:10] (3250.48s)
laws to sort of answer this problem.
[54:13] (3253.52s)
And so what we'll show is that we'll
[54:15] (3255.04s)
propose a few different scaling recipes
[54:17] (3257.84s)
and we'll sort of chase scaling recipes
[54:20] (3260.56s)
that monotonically decrease your i.i.d.
[54:23] (3263.36s)
validation loss. So sort of
[54:24] (3264.72s)
in-distribution generalization and we'll
[54:26] (3266.88s)
show that these scaling laws have a
[54:28] (3268.64s)
really clean functional form and they
[54:30] (3270.08s)
follow a super clean power law. And when
[54:31] (3271.92s)
you're able to fit these power laws,
[54:33] (3273.60s)
what you can do is you can estimate the
[54:35] (3275.68s)
best possible loss of your recipe by
[54:38] (3278.24s)
looking at the asymptote of the power law.
[54:40] (3280.40s)
And this is in some sense a
[54:41] (3281.68s)
quantification of your best possible
[54:43] (3283.36s)
performance under infinite compute. And
[54:46] (3286.48s)
our goal in this paper is sort of to
[54:48] (3288.40s)
think more carefully about what types of
[54:50] (3290.40s)
algorithms allow you to lower your
[54:52] (3292.64s)
compute asymptote. Uh and we're sort of
[54:54] (3294.96s)
going to chase these types of infinite
[54:56] (3296.48s)
compute wins. And so to start, I'm going
[54:58] (3298.48s)
to introduce this canonical setting that
[55:00] (3300.16s)
we referenced in this paper, which is
[55:02] (3302.00s)
that we're going to simulate a data
[55:03] (3303.44s)
constrained world by just constraining
[55:05] (3305.44s)
the number of pre-training tokens we
[55:06] (3306.80s)
have to be a very small amount. So we're
[55:08] (3308.72s)
going to assume access to only 200
[55:10] (3310.32s)
million tokens from DCLM, which is
[55:12] (3312.24s)
general web data. And what we're going
[55:14] (3314.48s)
to do is we're going to pre-train larger
[55:16] (3316.40s)
and larger models, which is the x-axis,
[55:18] (3318.72s)
using different kinds of pre-training
[55:20] (3320.16s)
recipes. And the y-axis here is going to
[55:22] (3322.72s)
be again our i.i.d. validation loss on
[55:25] (3325.60s)
DCLM. And our goal is going to be to
[55:28] (3328.40s)
find recipes that allow us to spend more
[55:30] (3330.56s)
compute and train larger models while
[55:32] (3332.64s)
monotonically decreasing our loss. So to
[55:34] (3334.80s)
start, we can consider sort of the
[55:36] (3336.08s)
obvious approach that you might take
[55:37] (3337.28s)
when you're in this setting, which is
[55:39] (3339.12s)
first to epoch your data. So to train on
[55:41] (3341.20s)
the same data points over and over again
[55:43] (3343.28s)
until you start overfitting as well as
[55:45] (3345.60s)
scaling up your model. So making your
[55:47] (3347.04s)
model larger and larger. And what we can
[55:49] (3349.12s)
do is we can do both of these at the
[55:50] (3350.72s)
same time. And we can do sort of an
[55:52] (3352.64s)
exhaustive grid search over these
[55:54] (3354.16s)
parameters until we
[55:56] (3356.32s)
start overfitting and then we do early
[55:57] (3357.76s)
stopping. And this is sort of the red
[55:59] (3359.76s)
line which is what we call the standard
[56:01] (3361.28s)
recipe. And what you'll see with the
[56:03] (3363.20s)
standard recipe is that even if you are
[56:05] (3365.60s)
willing to spend more compute, as you
[56:08] (3368.08s)
train more and more overparameterized
[56:10] (3370.00s)
models, you start to overfit more
[56:12] (3372.00s)
quickly and your loss starts to increase
[56:13] (3373.92s)
after a certain point.
[56:16] (3376.32s)
And so if you see this line, sort of the
[56:18] (3378.40s)
natural instinct you should have is how
[56:19] (3379.92s)
do we fix this? And one possible
[56:21] (3381.92s)
approach is to do really aggressive
[56:23] (3383.36s)
regularization. And so sort of the first
[56:25] (3385.84s)
baseline in this paper is going to be
[56:28] (3388.32s)
doing really aggressive regularization
[56:30] (3390.08s)
by cranking up your weight decay. And so
[56:32] (3392.48s)
what we do is we show that if you
[56:34] (3394.24s)
optimally tune your weight decay for
[56:36] (3396.32s)
each total parameter count. So we're
[56:38] (3398.88s)
going to optimally tune learning rate,
[56:40] (3400.56s)
weight decay, and epoch count for each
[56:42] (3402.08s)
one of these purple points. You can show
[56:44] (3404.24s)
that your loss follows a really clean
[56:46] (3406.32s)
power law as you increase the number of
[56:48] (3408.64s)
parameters in your model. And this is
[56:51] (3411.36s)
really aggressive regularization. So for
[56:53] (3413.36s)
context, we use weight decays that are
[56:55] (3415.52s)
something like 30 times larger than the
[56:57] (3417.44s)
weight decays that people do for compute
[56:58] (3418.96s)
optimal pre-training.
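To illustrate what fitting such a power law and reading off its asymptote looks like, here is a minimal sketch. It assumes the form loss = E + A / N with the exponent fixed at one, and the data points are synthetic, generated from a hypothetical `A`; they are not results from the paper.

```python
import numpy as np

def fit_inverse_law(n_params, losses):
    """Least-squares fit of loss = E + A / N, returning (A, E).
    E is the asymptote: the predicted loss as N grows without bound."""
    inv_n = 1.0 / np.asarray(n_params, dtype=float)
    A, E = np.polyfit(inv_n, np.asarray(losses, dtype=float), deg=1)
    return A, E

# Synthetic, made-up points generated from E = 3.43 and a hypothetical A = 5e7.
n = np.array([150e6, 300e6, 600e6, 1.4e9])
loss = 3.43 + 5.0e7 / n
A, E = fit_inverse_law(n, loss)
```

Fixing the exponent turns the fit into a straight line in 1/N; if the exponent were free, a nonlinear least-squares fit would be needed instead.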
[57:00] (3420.88s)
And so on the legend here, you can see
[57:02] (3422.88s)
the the sort of the form of this power
[57:04] (3424.64s)
law. And it has a few nice properties.
[57:07] (3427.60s)
One is that the exponent on the model
[57:10] (3430.24s)
parameters N is one. And this is
[57:12] (3432.40s)
actually predicted by sort of the data
[57:14] (3434.40s)
constraint theory. The second nice
[57:16] (3436.48s)
property that it has is that the scaling
[57:18] (3438.64s)
law has an asymptote which is 3.43 in this
[57:21] (3441.92s)
case. And this characterizes the
[57:24] (3444.00s)
performance of the best possible
[57:25] (3445.68s)
regularized model in this setting if you
[57:28] (3448.00s)
had like infinite compute. So you'll
[57:30] (3450.64s)
notice that the baseline approaches
[57:32] (3452.56s)
because they overfit more quickly. They
[57:34] (3454.08s)
don't even have a measurable asymptote.
[57:35] (3455.76s)
And so once we start going down the
[57:37] (3457.04s)
rabbit hole of regularization and these
[57:38] (3458.96s)
other types of classical machine
[57:40] (3460.16s)
learning techniques, there's a whole
[57:42] (3462.16s)
basket of techniques to to get into. And
[57:44] (3464.80s)
so perhaps maybe the most famous one is
[57:46] (3466.80s)
to do ensembling.
[57:48] (3468.72s)
And so what we show in this paper is
[57:50] (3470.56s)
that you can bring back ensembling in
[57:52] (3472.80s)
the modern world of pre-training
[57:54] (3474.48s)
language models and they turn out to be
[57:56] (3476.40s)
incredibly data efficient. So what these
[57:59] (3479.04s)
light blue points correspond to is they
[58:01] (3481.52s)
correspond to 300 million parameter
[58:03] (3483.76s)
models that we're ensembling with more
[58:06] (3486.32s)
and more members. So the fifth point
[58:08] (3488.72s)
will correspond to 1.5 billion
[58:11] (3491.84s)
total parameters, which is a five
[58:14] (3494.00s)
ensemble of 300 million parameter
[58:15] (3495.68s)
models. We show that you can also fit
[58:17] (3497.68s)
really clean scaling laws to ensembles.
[58:20] (3500.16s)
So you also get a power law that has
[58:22] (3502.08s)
exponent one in the number of ensemble
[58:23] (3503.68s)
members and it also has an asymptote. But
[58:26] (3506.64s)
most importantly the asymptote of
[58:28] (3508.56s)
ensembling is much lower than the
[58:30] (3510.56s)
asymptote of the regularized recipe. So
[58:32] (3512.80s)
it's giving you a true data efficiency
[58:34] (3514.48s)
win if you had an infinite amount of
[58:36] (3516.48s)
compute. There's also this interesting
[58:38] (3518.80s)
property which is that ensembles, if
[58:41] (3521.12s)
you do a compute matched comparison so
[58:42] (3522.80s)
the same number of parameters are
[58:44] (3524.40s)
actually better than the regularized
[58:46] (3526.08s)
recipe. So if your goal is just to train
[58:48] (3528.56s)
the best 1.5 billion parameter model
[58:51] (3531.44s)
it's better to train an ensemble of a
[58:53] (3533.12s)
bunch of small models when you're data
[58:54] (3534.40s)
constrained than to train one really
[58:56] (3536.32s)
large model. The last thing we show in
[58:58] (3538.32s)
this plot is that you can actually
[59:00] (3540.40s)
compose the benefits of regularization
[59:03] (3543.12s)
and ensembling. So one way to think
[59:05] (3545.36s)
about this is that regularization gives
[59:07] (3547.92s)
you this ability to continue to make the
[59:10] (3550.24s)
models larger and larger while
[59:12] (3552.56s)
ensembling introduces this new axis for
[59:15] (3555.04s)
scaling compute which is by training
[59:17] (3557.04s)
more and more models. And so what this
[59:19] (3559.68s)
gold line which we call the joint
[59:21] (3561.20s)
scaling recipe is we quantify this
[59:23] (3563.84s)
hypothetical performance if we were able
[59:26] (3566.16s)
to train an infinitely large
[59:29] (3569.04s)
ensemble of infinitely large models. And
[59:32] (3572.00s)
so the way in which we actually quantify
[59:33] (3573.68s)
this performance is we fit two scaling
[59:36] (3576.80s)
laws. So we'll take a double limit. What
[59:39] (3579.36s)
we'll first do is we'll train ensembles
[59:41] (3581.92s)
of 150 million parameter models, 300
[59:44] (3584.48s)
million parameter models and so on and
[59:46] (3586.40s)
so forth. And then we'll look at the
[59:48] (3588.40s)
asymptotes of the ensembles. And then
[59:50] (3590.48s)
we'll fit a second
[59:52] (3592.00s)
scaling law to the asymptotes of these
[59:53] (3593.52s)
ensembles. And this is essentially
[59:55] (3595.68s)
taking the first limit is taking the
[59:57] (3597.28s)
limit over K. And the second limit is
[59:59] (3599.44s)
taking the limit over N. And what we
[60:01] (3601.84s)
find is that if you're willing to sort
[60:03] (3603.92s)
of go through the effort of training
[60:05] (3605.44s)
infinitely large models and infinitely
[60:07] (3607.92s)
many ensembles, uh you get a huge loss
[60:10] (3610.08s)
improvement. And so all of these
[60:11] (3611.68s)
experiments are sort of in this toy data
[60:13] (3613.76s)
constrained setup of 200 million tokens.
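The double limit described above can be sketched in a few lines. This is only an illustration on synthetic numbers: it assumes both fits take the simple form E + A / x, and the loss table is generated from made-up constants rather than taken from the paper.

```python
import numpy as np

def asymptote(x, losses):
    """Fit loss = E + A / x by least squares and return the asymptote E."""
    inv_x = 1.0 / np.asarray(x, dtype=float)
    _, E = np.polyfit(inv_x, np.asarray(losses, dtype=float), deg=1)
    return E

def joint_asymptote(member_sizes, ensemble_counts, loss_table):
    """loss_table[i][j] is the loss of a K = ensemble_counts[j] ensemble of
    N = member_sizes[i] models. Take K -> infinity per size, then N -> infinity."""
    per_size = [asymptote(ensemble_counts, row) for row in loss_table]
    return asymptote(member_sizes, per_size), per_size

# Synthetic table: loss = 3.0 + 3e7 / N + 0.2 / K (all constants hypothetical).
sizes = np.array([150e6, 300e6, 600e6])
counts = np.array([1, 2, 3, 5])
table = 3.0 + 3.0e7 / sizes[:, None] + 0.2 / counts[None, :]
E_joint, E_per_size = joint_asymptote(sizes, counts, table)
```

The inner fit estimates what an infinitely large ensemble of each member size would reach; the outer fit extrapolates those estimates to infinitely large members.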
[60:15] (3615.92s)
And obviously this is very different
[60:17] (3617.36s)
from sort of the standard regime of
[60:18] (3618.64s)
pre-training. So what we also do in this
[60:21] (3621.12s)
paper is we spend some effort on trying
[60:22] (3622.88s)
to confirm that our recipes scale. So
[60:24] (3624.96s)
the first way in which we do this is
[60:26] (3626.40s)
that we build data scaling laws. So what
[60:28] (3628.64s)
data scaling laws are is that we repeat
[60:30] (3630.88s)
the exact same set of experiments from
[60:32] (3632.40s)
the previous slide at four different
[60:34] (3634.40s)
pre-training token counts up to 1.7
[60:36] (3636.96s)
billion uh tokens. And so for each slice
[60:40] (3640.24s)
on the x-axis at each seed token count,
[60:42] (3642.80s)
we're going to quantify the best
[60:44] (3644.16s)
possible performance of each recipe if
[60:46] (3646.56s)
we had an infinite amount of compute. So
[60:48] (3648.80s)
for the red points, they overfit more
[60:50] (3650.72s)
quickly. So these will be actual models.
[60:52] (3652.88s)
While for the purple and the gold
[60:54] (3654.40s)
points, these will correspond to sort of
[60:56] (3656.32s)
a single limit or a double limit. What
[60:58] (3658.24s)
these data scaling laws let us do is
[61:00] (3660.16s)
they let us quantify the data efficiency
[61:02] (3662.00s)
numbers of our approaches. So one way in
[61:04] (3664.96s)
which we do this is if we have some new
[61:06] (3666.64s)
recipe that we believe should improve
[61:08] (3668.72s)
upon the standard recipe that we're
[61:10] (3670.24s)
using right now, you can take the loss
[61:12] (3672.56s)
of your new recipe and you can project
[61:14] (3674.80s)
it onto the data scaling law. So the red
[61:17] (3677.12s)
line of a standard recipe and this
[61:19] (3679.52s)
projection lets you measure essentially
[61:21] (3681.52s)
the effective number of extra tokens
[61:23] (3683.52s)
that your algorithmic
[61:25] (3685.04s)
improvement is buying you. So in this
[61:27] (3687.36s)
case what we see is that this joint
[61:29] (3689.36s)
scaling recipe gives you roughly a 5x
[61:31] (3691.68s)
data efficiency win over uh the
[61:34] (3694.40s)
standard recipe. It's also worth noting
[61:36] (3696.80s)
that uh these data efficiency wins are
[61:39] (3699.12s)
something that we can realize with sort
[61:41] (3701.28s)
of finite models not just double limits.
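
[Editor's sketch] The projection onto the data scaling law can be written out as a short calculation. This is a sketch under assumptions, not the paper's code: it assumes the standard recipe follows `L(D) = E + B * D**(-beta)`, and the coefficients `E`, `B`, `beta`, and `B_new` below are hypothetical values picked so the ratio comes out near 5. Inverting the baseline law at the new recipe's loss gives the effective token count `D_eff = (B / (L_new - E)) ** (1 / beta)`, and the data efficiency is `D_eff / D`. If the new recipe shares the same `E` and `beta` and differs only in its coefficient, that ratio reduces to `(B / B_new) ** (1 / beta)`, which does not depend on `D`.

```python
def effective_tokens(loss, E, B, beta):
    """Invert the baseline law L(D) = E + B * D**(-beta) to find the D that reaches `loss`."""
    return (B / (loss - E)) ** (1.0 / beta)


# Hypothetical baseline (standard recipe) coefficients, for illustration only.
E, B, beta = 2.0, 10.0, 0.3

# Hypothetical improved recipe: same asymptote and exponent, smaller coefficient.
B_new = B / 5 ** beta


def new_recipe_loss(tokens):
    return E + B_new * tokens ** (-beta)


def data_efficiency(tokens):
    """Effective tokens bought by the new recipe, divided by the actual seed tokens."""
    return effective_tokens(new_recipe_loss(tokens), E, B, beta) / tokens


for tokens in (2e8, 1.7e9, 1e13):
    print(f"{tokens:.1e} seed tokens -> {data_efficiency(tokens):.2f}x data efficiency")
```

Under these assumptions the multiplier is the same at every seed token count, which is the reason matching exponents and asymptotes matter for extrapolation.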
[61:43] (3703.44s)
So for example if you're willing to
[61:44] (3704.72s)
train a five ensemble of 1 billion
[61:46] (3706.56s)
parameter models this will give you
[61:48] (3708.48s)
roughly a 3.7x data efficiency win. The
[61:50] (3710.88s)
other interesting aspect about these
[61:52] (3712.16s)
data scaling laws is if you look at the
[61:54] (3714.40s)
functional form in the legend, you'll
[61:56] (3716.32s)
see that they all have really similar
[61:57] (3717.84s)
exponents and they all have very similar
[61:59] (3719.44s)
asymptotes. And so the reason why this
[62:01] (3721.92s)
matters is this suggests that even if
[62:04] (3724.56s)
you repeated these experiments at a much
[62:06] (3726.96s)
much larger token scale, if you believe
[62:08] (3728.72s)
that these data scaling laws
[62:10] (3730.40s)
extrapolate, this data efficiency win is
[62:13] (3733.12s)
going to be constant over the actual
[62:14] (3734.88s)
number of token counts that you have. So
[62:17] (3737.04s)
they suggest that this double joint
[62:19] (3739.68s)
scaling recipe has a 5x data
[62:21] (3741.60s)
efficiency win even if you are willing
[62:23] (3743.52s)
to send the seed token count to like 10
[62:25] (3745.68s)
trillion tokens or whatever people are
[62:27] (3747.28s)
doing pre-training at these days. So now
[62:29] (3749.76s)
I'll go over some methods to sort of
[62:31] (3751.52s)
make this data efficiency win perhaps
[62:33] (3753.12s)
slightly more practical. And so even
[62:35] (3755.76s)
though these recipes require a lot of
[62:37] (3757.12s)
training compute we also show that you
[62:39] (3759.04s)
can reduce the amount of inference
[62:40] (3760.40s)
compute you need by using distillation.
[62:43] (3763.44s)
So the plot on the right here, the
[62:45] (3765.12s)
purple line corresponds to the same
[62:46] (3766.64s)
regularized recipe. The light blue
[62:48] (3768.88s)
points correspond to the same ensemble
[62:50] (3770.48s)
scaling. So we first show that what you
[62:52] (3772.96s)
can do is you can take an eight ensemble
[62:54] (3774.96s)
which is roughly 2.4 billion total
[62:56] (3776.72s)
parameters and you can distill it into a
[62:59] (3779.20s)
single dense 300 million parameter model
[63:01] (3781.36s)
which is the pink star in the bottom.
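
[Editor's sketch] A rough picture of what distilling an ensemble into one model can look like. The talk does not say which distillation variant is used or how ensemble members are combined, so everything here is an assumption: the ensemble averages the members' next-token probabilities, and the student is scored with cross-entropy against those soft targets. The helper names, the logits, and the loss values are hypothetical; only the parameter counts (eight members of 300 million parameters, roughly 2.4 billion total) come from the talk.

```python
import math


def softmax(logits):
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_probs(member_logits):
    """Average the members' next-token distributions (assumed combination rule)."""
    probs = [softmax(logits) for logits in member_logits]
    k = len(probs)
    return [sum(p[i] for p in probs) / k for i in range(len(probs[0]))]


def soft_target_loss(teacher_probs, student_logits):
    """Cross-entropy of the student against the teacher's soft targets."""
    student_probs = softmax(student_logits)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


def fraction_retained(base_loss, teacher_loss, student_loss):
    """Share of the teacher's loss improvement over the baseline that the student keeps."""
    return (base_loss - student_loss) / (base_loss - teacher_loss)


total_teacher_params = 8 * 300_000_000  # eight 300M members, about 2.4B in total

# Hypothetical logits over a 3-token vocabulary for three ensemble members.
teacher = ensemble_probs([[2.0, 0.5, -1.0], [1.5, 1.0, -0.5], [2.5, 0.0, -1.5]])

matched = soft_target_loss(teacher, [math.log(p) for p in teacher])  # student equals teacher
uniform = soft_target_loss(teacher, [0.0, 0.0, 0.0])                 # uninformed student
print(matched <= uniform)

# Hypothetical losses, chosen only to show how a retained fraction is computed.
print(round(fraction_retained(3.50, 3.30, 3.334), 2))
```

The soft-target loss is smallest when the student reproduces the ensemble's averaged distribution, so a single small model is pushed toward the ensemble's behavior and inference only pays for the student.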
[63:03] (3783.92s)
And you can do this while retaining
[63:05] (3785.36s)
roughly 83% of the loss improvement. So
[63:08] (3788.88s)
this shows you that data efficiency is
[63:11] (3791.12s)
not something that you need a large
[63:13] (3793.52s)
amount of inference compute for. If
[63:15] (3795.44s)
you're willing to amortize the
[63:17] (3797.52s)
test time compute during training time,
[63:20] (3800.24s)
you can get an extremely data efficient
[63:21] (3801.92s)
model that's still very very small. The
[63:25] (3805.20s)
other surprising result we show in this
[63:26] (3806.72s)
section is that you can do
[63:28] (3808.24s)
self-distillation to even improve your
[63:30] (3810.08s)
loss. So with self-distillation, what
[63:32] (3812.24s)
we're doing is we're starting with the
[63:33] (3813.92s)
300 million parameter model at the start
[63:36] (3816.08s)
of the light blue curve and then we're
[63:38] (3818.32s)
distilling this model into a fresh 300
[63:41] (3821.04s)
million parameter model which is the
[63:42] (3822.48s)
green star. And what we find is very
[63:44] (3824.80s)
surprisingly even doing self
[63:46] (3826.48s)
distillation gives you huge loss
[63:47] (3827.84s)
improvement. It even beats the asymptote
[63:50] (3830.32s)
of the regularized recipe. This is
[63:52] (3832.08s)
actually pretty counterintuitive and we
[63:54] (3834.48s)
have a longer sort of uh description of
[63:56] (3836.96s)
this result in the paper but it turns
[63:58] (3838.96s)
out to have pretty surprising
[64:00] (3840.16s)
connections to uh ensembling and there's
[64:02] (3842.88s)
actually a view uh from prior work on
[64:05] (3845.44s)
viewing self-distillation as implicitly
[64:07] (3847.84s)
training a two ensemble. We also show
[64:09] (3849.68s)
that even though we're only chasing IID
[64:11] (3851.92s)
val loss in all of our experiments,
[64:14] (3854.32s)
pretty much all of the trends in this
[64:16] (3856.00s)
paper directly work on downstream
[64:18] (3858.40s)
benchmarks. And this is like a fully
[64:21] (3861.04s)
held out sort of test set where we only
[64:23] (3863.92s)
looked at the benchmarks at the very end
[64:25] (3865.28s)
of the paper because the advisers told
[64:26] (3866.72s)
us to. Um, and you can see that
[64:29] (3869.44s)
everything tracks: the standard recipe
[64:31] (3871.60s)
overfits. Still model scaling gives you
[64:34] (3874.00s)
improvements. Ensembling is even better.
[64:36] (3876.80s)
And you can still retain a lot of the
[64:38] (3878.32s)
benefits through distillation. And
[64:39] (3879.92s)
finally, we also show that you can do
[64:41] (3881.68s)
this for other settings beyond
[64:43] (3883.12s)
pre-training. So things like continued
[64:44] (3884.48s)
pre-training. So we consider a setup
[64:46] (3886.72s)
where you're trying to CPT a 3B model
[64:49] (3889.84s)
and we assume access to sort of this
[64:52] (3892.16s)
restricted set of 4 billion math related
[64:54] (3894.56s)
tokens where the whole corpus of data is
[64:57] (3897.52s)
actually 73 billion tokens. And what we
[64:59] (3899.68s)
show is that if you're willing to do
[65:00] (3900.96s)
these data efficiency tricks like
[65:03] (3903.36s)
aggressive epoching and things like
[65:04] (3904.96s)
ensembling, you can match the
[65:07] (3907.04s)
performance of training on the full 73
[65:08] (3908.88s)
billion tokens even using only 4 billion
[65:11] (3911.68s)
tokens which is roughly a 17x data
[65:14] (3914.24s)
efficiency win. So to sort of wrap up
[65:16] (3916.24s)
this talk, maybe the main point I want
[65:18] (3918.64s)
to make is that when you're constrained
[65:20] (3920.80s)
by data and you're unconstrained by
[65:22] (3922.48s)
compute, in this sort of new algorithmic
[65:24] (3924.24s)
regime, the types of algorithmic choices
[65:26] (3926.64s)
you make matter a lot and we should be
[65:28] (3928.96s)
willing to sort of rethink every aspect
[65:30] (3930.56s)
of a stack. In this paper, we mostly do
[65:33] (3933.28s)
this by revisiting a lot of these
[65:35] (3935.20s)
classical ideas from uh machine learning
[65:37] (3937.52s)
and deep learning. Things like
[65:39] (3939.04s)
regularization, ensembling, distillation
[65:42] (3942.16s)
have existed for many many years.
[65:44] (3944.48s)
And we also introduced this evaluative
[65:46] (3946.72s)
tool of asymptotes. And maybe the hope
[65:49] (3949.04s)
is that if you're willing to chase
[65:50] (3950.72s)
algorithms that have lower compute
[65:52] (3952.48s)
asymptotes, uh these will give you like
[65:54] (3954.96s)
better ideas for data efficiency. But
[65:56] (3956.72s)
like ultimately what we really want to
[65:58] (3958.24s)
do is we want these asymptotes to help us
[66:00] (3960.40s)
develop new and better ideas under
[66:02] (3962.48s)
infinite compute that don't already
[66:04] (3964.40s)
exist. And so if you're interested in
[66:06] (3966.56s)
the details, that's a QR code for the
[66:08] (3968.32s)
paper. And we've also done some
[66:10] (3970.16s)
follow-up work on looking at how
[66:11] (3971.68s)
synthetic data interacts with data
[66:13] (3973.20s)
efficiency. So feel free to check that
[66:14] (3974.88s)
out as well if you're interested.
[66:16] (3976.48s)
Thanks. [applause]
[66:22] (3982.96s)
>> All right. Thank you guys so much for
[66:24] (3984.88s)
coming. This is like a dream come true.
[66:26] (3986.88s)
I'm in one of my favorite places that um
[66:29] (3989.84s)
was one of the most important places of my life and
[66:32] (3992.08s)
now I get to talk about AI here. So
[66:34] (3994.56s)
super super fun. I think there's a lot
[66:36] (3996.72s)
of potential for this club. I think I
[66:38] (3998.56s)
don't have nearly, you know, 1% of all
[66:41] (4001.76s)
the ideas that we probably have to make
[66:44] (4004.16s)
this club really great um in all of your
[66:46] (4006.80s)
heads. And so we want to make sure all
[66:49] (4009.44s)
of you guys get in on the Slack. So I'll
[66:51] (4011.12s)
make sure that you know, please send me
[66:52] (4012.48s)
a note if you're not already on there.
[66:54] (4014.08s)
And then we can kind of make this thing
[66:55] (4015.28s)
whatever we want. So it's kind of fun
[66:57] (4017.60s)
and I intend to. So like please come
[67:00] (4020.32s)
with ideas. We want to make this super
[67:01] (4021.92s)
fun. Um obviously, you know, there's
[67:03] (4023.92s)
some ground rules, be respectful, all
[67:05] (4025.52s)
that kind of stuff. Um, and definitely
[67:07] (4027.20s)
be involved. And that's kind of the
[67:08] (4028.64s)
the biggest thing that we really only
[67:10] (4030.56s)
really ask. That's all I got. That's a
[67:12] (4032.08s)
wrap. Go get some boba tea. Thank you.