How good of a BERT can one get in ONE DAY on ONE GPU?

With all the recent studies about scaling compute up, this paper takes a refreshing turn and does a deep dive into scaling down compute.

It's well written and chock-full of insights. Here is my summary and my opinions.

https://arxiv.org/abs/2212.14034 by @jonasgeiping and @tomgoldstein

🧶 1/N

Cramming: Training a Language Model on a Single GPU in One Day

Recent trends in language modeling have focused on increasing performance through scaling, and have resulted in an environment where training language models is out of reach for most researchers and practitioners. While most in the community are asking how to push the limits of extreme computation, we ask the opposite question: How far can we get with a single GPU in just one day? We investigate the downstream performance achievable with a transformer-based language model trained completely from scratch with masked language modeling for a single day on a single consumer GPU. Aside from re-analyzing nearly all components of the pretraining pipeline for this scenario and providing a modified pipeline with performance close to BERT, we investigate why scaling down is hard, and which modifications actually improve performance in this scenario. We provide evidence that even in this constrained setting, performance closely follows scaling laws observed in large-compute settings. Through the lens of scaling laws, we categorize a range of recent improvements to training and architecture and discuss their merit and practical applicability (or lack thereof) for the limited compute setting.

2/N First, the setting. See screenshot for full info, but in short:
- 24h of training on a single good GPU (2080ti or a4000)
- Transformer architecture, modifications are OK
- MLM training from scratch
- NO use of pre-trained anything in any way (except tokenizer)
- Any dataset

3/N The implementation is relatively basic. I like that they refrain from jumping to specialized/optimized setups, which would mean starting in a local minimum right away.

Note the odd last paragraph, which is greyed out. That's a weird LaTeX mistake to make, isn't it?

4/N Data: en
- Short sequence length 128 and packing with <sep>. I like the simplicity!
- <cls> token seems unnecessary; we found the same with ViT
- Use large batches by accumulating gradients across micro-batches
- 1 epoch (cc @aran)
- Grey text isn't an error; it contains all the negative results. The most interesting part!
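The packing they describe can be sketched in a few lines. This is my own illustration, not the paper's code; SEP_ID is a placeholder token id, and the 128-token length is the setup from the paper:

```python
SEP_ID = 1    # placeholder id for the <sep> token
SEQ_LEN = 128

def pack_sequences(docs, seq_len=SEQ_LEN, sep_id=SEP_ID):
    """docs: iterable of token-id lists. Concatenate documents with a
    separator and slice into fixed-length chunks; no padding needed."""
    buffer = []
    for doc in docs:
        buffer.extend(doc)
        buffer.append(sep_id)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]
    # leftover tokens shorter than seq_len are dropped in this sketch

chunks = list(pack_sequences([[5, 6, 7]] * 100, seq_len=8))
```

Every chunk is exactly seq_len tokens, so every batch is densely packed with no wasted compute on padding.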

5a Architecture (sub-thread).

Super interesting and echoes our experience in vision: with enough data, all variants reach ~ the same loss in the same wall-clock time. Faster models need to see more tokens. In other words, with good implementations, it's hard to cheat wall-clock.

5b Same theme again: restrict model changes to those that keep the same capacity (~params for Transformer MLMs with fixed seqlen) but speed things up.

It's a shame that the vast majority of papers (sometimes including mine) completely ignore reporting wall-clock speedups or slowdowns.

5c changes include:

- SA: remove biases, many variants tried, none kept.
- MLP: remove biases, make gated, nothing else.
- Scaled sin embedding + LN
- pre-norm helps, but only when increasing LR
- In the head, MLP can be dropped (same with ViT).
- Again: gray text interesting!

6a/N training

- Stick to the simplest MLM objective
- Optimizer: Adam. No win from fancier.
- I want to point out that AdaFactor is meant to save memory but behave like Adam, so no win is a win!
- They mention no win from Shampoo (cc @nan) but aren't confident it's a good impl.

6b training: lr schedule!

They tried many, but this is where I disagree with the paper.
Most of the schedules they tried either don't warm up (-> forced to a lower peak LR!) or don't cool down to 0 at the end.

The only two that work clearly better than the rest are the only two with warmup and cooldown!

addendum to 6b: the figure from my screenshot is in the appendix. Other papers have shown even pre-norm needs warmup.
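To make the warmup+cooldown point concrete, here's a sketch of a one-cycle-style schedule: linear warmup to a peak, then cosine decay to zero. Peak LR and warmup fraction here are illustrative, not the paper's values:

```python
import math

def one_cycle_lr(step, total_steps, peak_lr=1e-3, warmup_frac=0.1):
    """Linear warmup to peak_lr, then cosine cooldown to 0."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    # cosine decay from peak to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The warmup lets you reach a higher peak LR without diverging, and the cooldown to exactly 0 squeezes out the last bit of loss; schedules missing either piece lose on one of the two.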

6c training:
- no dropout, token-drop, or length curriculum.
- micro-batch of 96, accumulated into an effective batch of 1.5-4k, linearly increased during training. Auto-tuning the ramp looks mostly linear anyway.
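The batch ramp can be sketched like this. My own illustration; the endpoints match the thread's numbers, but the exact ramp shape in the paper may differ:

```python
def accum_steps(step, total_steps, micro_bs=96,
                start_batch=1536, end_batch=4096):
    """Number of micro-batches to accumulate at a given step, so the
    effective batch grows linearly from start_batch to end_batch."""
    target = start_batch + (end_batch - start_batch) * step / total_steps
    return max(1, round(target / micro_bs))
```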

7/N data

- Try Pile subsets, C4, Books+Wiki
- dedup (exact substring) not helpful
- remove hard-to-compress data ("t=0.3"): keep a document only if ntokens < 0.3 * nchars
- sort: data with frequent tokens first (think "easy/common text first")
- grow batch-size at end
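The t=0.3 filter is simple to write down. Sketch only, with str.split standing in for the real tokenizer:

```python
def keep_document(text, tokenize, t=0.3):
    """Keep a document only if it tokenizes compactly: fewer than
    t tokens per character. Text that needs many tokens per character
    (rare words, noise, non-English) gets dropped."""
    tokens = tokenize(text)
    return len(tokens) < t * len(text)
```

The intuition: a subword tokenizer is itself a compressor, so a high token-per-character ratio flags text the tokenizer handles poorly.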

8 results

left: overall, it's getting pretty close to original BERT which used 45-136x more total FLOPS (4d on 16 TPUs)
right: and when training for 16x longer (2d on 8 GPUs), the same recipe actually improves on original BERT quite a bit, reaching RoBERTa levels of performance.

9/9 final thoughts.

- I really like the "trend reversal" of seeing how much can be done with limited compute.
- I am a big fan of the gray text passages for things that were tried but didn't work.
- The lr sched part is fishy, but not super important.
- Impressive bibliography!

PS: This thread took me almost as long as a paper review. Looks like I procrastinate my CVPR reviews by making twitter paper reviews instead ¯\_(ツ)_/¯

Meta: I wrote the thread on twitter and copied it over here, hence the short toots. I tried, but the UI for writing threads here (web client) is absolutely abysmal...

@lb can we write threads here?
@gusthema I do this by writing the first post and then replying to it, with slightly changed visibility setting to "unlisted" to avoid spamming public feeds. I have no idea if that's the best way, but it's definitely not an enjoyable way!
@lb thanks
I'll try this next time!
@lb @gusthema That is indeed the recommended way. It's sub-ideal for both authors and followers, sadly.

@lb thanks. Much appreciated that you took the time.

You referred to reporting wall clock time. As a normalizer, like FLOPS?

Lucas Beyer on Twitter

“I beg the community to please stop using parameters as x axis. It is *especially* meaningless for ViT-style models: B/32 has *more* params than B/16, but is faster, less capacity, and performs worse. Use img/s ideally, or flops if need. (Not singling this paper, so so many!)”

@lb Also, while listening to a podcast with Tri Dao I picked up on delays due to memory transfer not being reflected by FLOPS, but by wall clock time (e.g. flash attention). (And yes, using this abandoned Mastodon feed as a notebook now).