kouteiheika


This account is a replica from Hacker News.

> You can only give it a try, but don't get your hopes high on a large context.

You may or may not know this, but when training off-the-shelf LLMs (i.e. ones which have a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss, and it gets worse the more tokens you stuff into your batch. So always use a fused cross-entropy kernel.
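To make the idea concrete, here is a minimal pure-Python sketch of why chunking helps. It is not a fused kernel and not necessarily how any particular library does it; `project`, `naive_loss` and `chunked_loss` are hypothetical names, and the toy data is made up. The naive version builds the full tokens-by-vocabulary logits matrix before computing the loss, while the chunked version only ever holds `chunk_size` rows of logits at once and arrives at the same number.

```python
import math

def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def project(hidden, weight):
    # Logits for one token: one dot product per vocabulary entry.
    return [sum(h * w for h, w in zip(hidden, row)) for row in weight]

def naive_loss(hiddens, weight, targets):
    # Materializes the full [tokens, vocab] logits matrix first.
    logits = [project(h, weight) for h in hiddens]
    losses = [logsumexp(row) - row[t] for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)

def chunked_loss(hiddens, weight, targets, chunk_size=2):
    # Only chunk_size rows of logits exist at any one time.
    total = 0.0
    for start in range(0, len(hiddens), chunk_size):
        stop = start + chunk_size
        chunk_logits = [project(h, weight) for h in hiddens[start:stop]]
        for row, t in zip(chunk_logits, targets[start:stop]):
            total += logsumexp(row) - row[t]
    return total / len(hiddens)

# Toy data (made up): 5 tokens, hidden size 2, vocabulary of 3.
hiddens = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2], [0.0, 0.7], [0.9, -0.1]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
targets = [0, 2, 1, 1, 0]
```

A real kernel would also have to produce the gradients chunk by chunk during the same pass, which this sketch leaves out.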

For example, for a Gemma 2 model with 2B parameters at a batch size of 8k, this consumes 24GB of VRAM by default (!). You can fuse your cross-entropy loss with @torch.compile, which can cut this memory usage down to something like a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
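A rough back-of-the-envelope check of that 24GB figure, under assumptions that are mine rather than stated above: "8k" is read as 8192 tokens, the Gemma 2 vocabulary is taken as 256,000 entries, and the unfused path is assumed to keep three full-size fp32 buffers alive (logits, softmax output, and the gradient with respect to the logits).

```python
TOKENS = 8 * 1024        # assumption: "8k" means 8192 tokens in the batch
VOCAB = 256_000          # assumption: Gemma 2 vocabulary size
BYTES_FP32 = 4
FULL_SIZE_BUFFERS = 3    # assumption: logits, softmax output, logits gradient

def gib(num_bytes):
    # Convert bytes to GiB.
    return num_bytes / 2**30

one_buffer = TOKENS * VOCAB * BYTES_FP32
total = one_buffer * FULL_SIZE_BUFFERS

print(f"{gib(one_buffer):.2f} GiB per buffer, {gib(total):.2f} GiB total")
```

Under those assumptions a single buffer is about 7.8 GiB and three of them come to about 23.4 GiB, which is in the same ballpark as the number quoted above.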

This isn't really anything new; I've been doing something like this for quite a while, I just haven't bothered writing a paper. (: Probably anyone who would seriously tackle the problem of "how do I train a huge model on a tiny amount of VRAM?" would come up with something similar.

However, most people in the field don't, because the actual practical utility of training huge models on a single GPU is quite low. (e.g. they got 341 tok/s for a 14B model on a single 3090, while with my method I was getting ~1k tok/s on a single 4090; that's still very slow)

Also, there are more tricks one can use to speed up training/lower VRAM usage which they're not using. For example, you don't need any gradient offloading (you can just accumulate the gradients directly into the optimizer's state if you modify your optimizer), you can use Muon instead of Adam (which needs only half the VRAM of Adam), you can use quantization (both for the parameters and for the optimizer state; e.g. I found Muon quantized to 4-bit working relatively well), etc.
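To put rough numbers on the optimizer part, here is a small sketch of the state-size arithmetic for the 14B model mentioned above. The assumptions are mine: Adam keeps two buffers per parameter (first and second moments), Muon keeps one momentum buffer per parameter, the unquantized state is fp32, and the overhead of quantization scales is ignored. `state_gb` is a hypothetical helper, and real setups may differ (for instance if some parameters are still handled by a different optimizer).

```python
PARAMS = 14_000_000_000  # the 14B model from the comparison above

def state_gb(params, buffers_per_param, bits_per_value):
    # Optimizer state size in GB (10^9 bytes), ignoring any per-block scales.
    return params * buffers_per_param * bits_per_value / 8 / 1e9

adam_fp32 = state_gb(PARAMS, buffers_per_param=2, bits_per_value=32)
muon_fp32 = state_gb(PARAMS, buffers_per_param=1, bits_per_value=32)
muon_4bit = state_gb(PARAMS, buffers_per_param=1, bits_per_value=4)

print(f"Adam fp32: {adam_fp32:.0f} GB")
print(f"Muon fp32: {muon_fp32:.0f} GB")
print(f"Muon 4-bit: {muon_4bit:.0f} GB")
```

Under those assumptions the state shrinks from 112 GB to 56 GB by dropping one buffer, and to 7 GB once the remaining buffer is stored in 4 bits.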