The Underlying Math of Training and Serving Large Language Models

#Tech

In a blackboard lecture, Reiner Pope (CEO of MatX) digs into the underlying math of how large language models (LLMs) are trained and served.

By analyzing key factors such as compute performance and memory bandwidth, he shows how batch size affects latency and cost, and explains why you can pay more to get faster responses.

He also covers technical details such as model architecture and model parallelism, and deduces the cost of long-context memory from API pricing.

The lecture further touches on topics such as model over-training and the convergent evolution of neural networks and cryptography, offering a valuable perspective on how AI systems actually work.

Preview of the original (English · first 3 sections only)

Did a very different format with Reiner Pope - a blackboard lecture where he walks through how frontier LLMs are trained and served. It's shocking how much you can deduce about what the labs are doing from a handful of equations, public API prices, and some chalk. It's a bit technical, but I encourage you to hang in there – it's really worth it.

There are less than a handful of people in the world who understand the full stack of AI, from chip design to model architecture, as well as Reiner. It was a real delight to learn from him.

Wrote up some flashcards and practice problems to help myself retain what Reiner taught. Hope it's helpful to you too!

Reiner is CEO of MatX, a new chip startup (full disclosure - I'm an angel investor). He was previously at Google, where he worked on software efficiency, compilers, and TPU architecture. Highly recommend the scaling book he coauthored for further study.

Watch this one on YouTube so you can see the chalkboard.

Jane Street needs constant access to incredibly low-latency compute. I recently asked one of their engineers, Clark, to talk me through how they meet these demands. Our conversation—which touched on everything from FPGAs to liquid cooling—was extremely helpful as I prepped to interview Reiner. You can watch the full discussion and explore Jane Street's open roles at janestreet.com/dwarkesh

Google's Gemma 4 is the first open model that's let me shut off the internet and create a fully disconnected "focus machine". This is because Gemma is small enough to run on my laptop, but powerful enough to actually be useful. So, to prep for this interview, I downloaded Reiner's scaling book, disconnected from wifi, and used Gemma to help me break down the material. Check it out at goo.gle/Gemma4

Cursor helped me turn some notes I took on how gradients flow during large-scale pretraining into a great animation. At first, I wasn't sure the best way to visualize the concept, but Cursor's Composer 2 Fast model let me iterate on different ideas almost instantaneously. You can check out the animation in my recent blog post. And if you have something to visualize yourself, go to cursor.com/dwarkesh

(00:00:00) – How batch size affects token cost and speed
(00:32:09) – How MoE models are laid out across GPU racks
(00:47:12) – How pipeline parallelism spreads model layers across racks
(01:03:37) – Why Ilya said, "As we now know, pipelining is not wise."
(01:18:59) – Because of RL, models may be 100x over-trained beyond Chinchilla-optimal
(01:33:02) – Deducing long context memory costs from API pricing
(02:04:02) – Convergent evolution between neural nets and cryptography

Dwarkesh Patel: Today, I'm interviewing Reiner Pope, who is the CEO of MatX, which is a new chip startup. Previously, he was doing TPU architecture and many other things at Google. This is a very different format from my usual interviews. This is going to be a blackboard lecture. We're going to get up in a second. We in fact built this whole new studio with specifically this format in mind, so it's a pleasure to get to inaugurate it with you.

We're going to be talking about model architecture, ML infra, and many other things. The reason I think it's an important topic is because once you understand how training and inference work in a cluster, a lot of things—about why AI is the way it is, why AI architectures are the way they are, why API prices are the way they are, and fundamentally why AI progress is the way it is—start making sense.
You need to understand the details to get there, and you need a blackboard to understand the details. Reiner, thank you so much for doing this.

Reiner Pope: Very happy to be here.

Dwarkesh Patel: Full disclosure, I am an angel investor in MatX, but that's unrelated to this podcast. Reiner, to kick us off I'll ask this question. We have a couple of companies like Claude and Codex and Cursor offering something like Fast Mode, where for 6x the price, they'll stream you tokens at 2.5x the speed. Mechanically, I'm curious what's going on here. Why is it the case that you can pay more to get faster latency?

Two, could you keep going? Could you pay 100x more and somehow get much faster speeds? Three, could you go the other way? Could you have something like Claude Code "Slow Mode", where if you are willing to wait for minutes on end, you could get even cheaper prices? Maybe this will help motivate the analysis that you'll be doing through the lecture.

Reiner Pope: Great. To jump to the conclusion a little bit, the big effect is batch size. What we're going to do now is quantify exactly what that looks like and what its implications are on latency and cost. There's another effect, which you can call speculative decoding or multi-token prediction. We can maybe come back to that later, but the first thing that we'll talk through is batch size.

What I'd like to introduce is the two principles of analysis. First, we're going to look at a roofline analysis of how we run a transformer model on a cluster of chips. We'll take a Blackwell NVL72 cluster, so a rack of 72 GPUs. The roofline analysis means we look at memory bandwidth and compute performance. The other side of that is that we're going to look at just two simple factors of the model: the time to operate on the weights, and the time to operate on the context, the KV cache.

Let's jump in. We're going to try and estimate the time that it takes to run an inference of a certain shape. We're not perfect here. We can't exactly predict the time, so instead we're going to approximate. We're going to say that the time must be greater than or equal to a certain quantity. We're going to consider two different aspects: the time it takes to do the memory fetches, and the time it takes to do the compute. It will turn out that this gives us very strong predictive power, even with a simple model.

One by one, what is the time that it takes to do the compute? There are really two things I need to do in the compute. I need to multiply by all of the active parameters, and then I need to do some work on the attention. Multiplying by all the active parameters, I have a certain batch size that I'm running, and I've got a number of active parameters in my model. Then I'm just going to divide this by the compute throughput, which is the FLOPs of the chip. This is a hardware concern.

This accounts for all of the compute time for all of the weight matrix multiplies. There's a little caveat here. We've ignored the time to do any of the attention computation, but that in general will be quite small in comparison to this. So we'll ignore this.

Dwarkesh Patel: I'll just interrupt from time to time to ask some very naive questions or to clarify some basic points. For the audience, you're not serving one user at a time. The batch refers to the fact that you're serving many different users at the same time, and that's a whole batch.

Reiner Pope: I can motivate the batch at least a little bit. We will see exactly why batch is such a favorable optimization.
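For readers following along without the chalkboard, the compute-side bound Reiner has described so far can be written out as below. The symbols are my own shorthand and do not appear in the transcript: B is the batch size, N_act the active parameter count, and C the chip's compute throughput in FLOP/s.

```latex
% Roofline lower bound on the time T for one decode step (notation is mine):
%   B = batch size, N_act = active parameters, C = compute throughput (FLOP/s)
T \;\ge\; T_{\mathrm{compute}}, \qquad T \;\ge\; T_{\mathrm{memory}},
\qquad
T_{\mathrm{compute}} \;\approx\; \frac{B \cdot N_{\mathrm{act}}}{C}
```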
What will turn out to be the case is that if you do not batch together many users, the cost and the economics you get can be a thousand times worse than if you do batch many users together. We'll be able to see that quite explicitly.

Then, number of active parameters. If I look at, for example, a DeepSeek model, the DeepSeek V3 model has about 37 billion active parameters, and 700 billion total parameters. We're focusing on just the ones that are active for a single AI token.

We're modeling compute performance. I'm going to keep writing equals, but in all of these cases, you can think of this time as being at least this much, and maybe there will be some terms we ignored.

On the memory side, what do we need to do with memory? We need to fetch all of the weights, so there is some time to fetch the total number of parameters, not just the active parameters. There's weight fetch time, and then in addition, there's a KV cache fetch time. This actually depends on batch size. For every element of the batch, we have to fetch an entire context length worth of tokens, and there's a size per token, bytes for one token. This is a model parameter.

Dwarkesh Patel: Maybe just backing up, let's explain what the KV cache is real quick.

Reiner Pope: When I do a forward pass… Let me draw how the autoregressive inference works. This is during decode. If I have a bunch of text tokens… I'm drawing a tensor because ultimately the tokens are represented as a tensor in some embedding dimension. In this direction, I have the sequence length.

The work of running a decode is that I have to run each token through a whole bunch of matrix multiplies over a bunch of different layers. In general, I'm going to have to do that work over all of these tokens. But one step of decode is to produce just this one additional token up here.

What I'm going to do there is run a full forward pass of multiplying by all of the weight matrices in the entire model. But then I've got this attention mechanism where this token is looking at all of the past tokens, and what is it looking at specifically? It is looking at some internal representation that the model has produced of the tokens, and we call that the KV cache. This process of this single token attending to all of the history of tokens is attention. It is mostly dominated by memory fetches rather than matrix multiplies.

So we've got the amount of memory that we're fetching shown over here, and then this is of course just divided by the memory bandwidth, so the memory bytes per second. In fact, these equations here are enough for us to now draw some fit lines. The things that we'd like to look at are sensitivity to batch, and then also, which we'll draw separately, to context length. We said that the big effect you can get is some trade-off in latency versus cost in batch size.

Let's draw them out. I think there are just really two graphs that we want to draw. We'll first draw batch size versus time here. When we look at the shape of this, we've got a maximum of the sum and then another term. Let's look at these terms one by one and how they scale: the time for compute and memory, and how they show up.

Let's first look at this compute time. This is just purely linear in batch size with no offset, so it is some curve like this. This is t compute. On the memory side, we've got some portion here that is just this constant in some base offset here, which is the weight fetch. Finally, we have this term here, which is the KV fetch, which is pretty linear in batch size, and so it looks like that.
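Both terms of the bound are now on the board: a compute term, and a memory term made up of a weight fetch plus a KV-cache fetch. The sketch below puts the same structure into code so the curves can be explored numerically. Every constant in it is an illustrative assumption of mine: the parameter counts follow the DeepSeek V3 figures quoted above, while the rack throughput, memory bandwidth, and KV bytes per token are rough placeholder values, not vendor specs or lab numbers.

```python
# Minimal roofline sketch of one decode step, following the two terms described above.
# All constants are illustrative assumptions, not vendor specs or lab numbers.

N_TOTAL = 700e9            # total parameters (DeepSeek V3 scale, per the transcript)
N_ACTIVE = 37e9            # active parameters per generated token
WEIGHT_BYTES = 0.5         # bytes per weight, assuming FP4 storage
KV_BYTES_PER_TOKEN = 70e3  # KV-cache bytes per token of context (rough assumption)

RACK_FLOPS = 7e17          # aggregate compute of the rack, FLOP/s (illustrative)
RACK_MEM_BW = 6e14         # aggregate HBM bandwidth of the rack, bytes/s (illustrative)


def decode_step_time(batch_size: int, context_len: int) -> float:
    """Roofline lower bound (seconds) on the time for one decode step."""
    # Compute side: ~2 FLOPs per active parameter per token (multiply + add).
    t_compute = 2 * batch_size * N_ACTIVE / RACK_FLOPS
    # Memory side: fetch every weight once, plus each sequence's KV cache.
    t_weights = N_TOTAL * WEIGHT_BYTES / RACK_MEM_BW
    t_kv = batch_size * context_len * KV_BYTES_PER_TOKEN / RACK_MEM_BW
    return max(t_compute, t_weights + t_kv)


for b in (1, 64, 512, 2048, 8192):
    t = decode_step_time(b, context_len=32_000)
    print(f"batch={b:5d}  step={t*1e3:7.2f} ms  time/token={t/b*1e3:.4f} ms")
```

Running it shows the behavior discussed below: at small batch the step time is pinned by the weight fetch, while at large batch and long context the KV fetch and compute terms take over.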
The sum of this plus this maxed with this… Let's at least first draw the sum. The two memory times in conjunction end up looking on this curved slope like this. Then the overall maximum is—I'll draw a little thicker here—the maximum of these two curves.

What does this mean? This is a latency plot. If I grow my batch size, initially I get some not very strong dependence on batch size, so there is some lower bound on latency here. This already partially answers the question. For a given hardware configuration—and we can talk about varying the hardware configuration—there is a lower bound on latency. It is simply that I need to read all of my total parameters from memory into the chips, and that takes a certain amount of time. If I use all of my memory bandwidth, I can't do any better than that.

Dwarkesh Patel: It seems like the way you've drawn the slopes for compute time and how the KV grows—and what implication the KV has on memory time—

Reiner Pope: What if this were above or below?

Dwarkesh Patel: Yeah, is that necessarily the case? If this is always true, then as batch size grows compute always dominates KV, which suggests that if you have a big enough batch size, maybe memory is never an issue.

Reiner Pope: This is really sensitive to the context length, so I think we should come back and explore this. As you vary the context length, the KV fetch time will go up and up, and that will cause a transition from compute-limited to memory-limited.

Dwarkesh Patel: Is there something especially significant about the slope being exactly the slope of the compute time?

Reiner Pope: Whenever we have balance points, it says that you're getting it exactly right. For the particular context length where the slopes match, that says I am equally memory-bound and compute-bound, which is a really desirable place to be.

Dwarkesh Patel: This is a very simple algebra problem, but suppose the optimal is 100K context length, and you go to 200K context length. Does your MFU go down to 50%? Does it have a humongous impact on MFU to be slightly outside of the optimal context length range, the Goldilocks zone?

Reiner Pope: That's right. That is true as modeled here. There is a key point here that I'm modeling the memory fetch as linear in context length. That depends on model architecture. It is true for all of the model architectures with dense attention. Sparse attention actually scales much better than that.

Dwarkesh Patel: Got it. Is sparse attention what everybody uses in practice?

Reiner Pope: I'm pretty excited about sparse attention. It's hard to know what the labs are using. DeepSeek has published a sparse attention mechanism. I'll just put a plug in that some of the DeepSeek papers that have published sparse attention end up putting a square root in this term.

So far, we've looked at the latency. It's hard to read off cost from this. If I think about what cost means… To run this inference, I'm going to use the GPU for a certain number of seconds, like one millisecond or 20 milliseconds. I have to pay the rental time for that time. So it's $2/hour per GPU or something like that.

That's the cost of this inference, but how many tokens have I processed during that inference? That is the batch size. What we actually want to plot is the cost versus batch size, which is t over B versus batch size. This is the cost per token. We have to imagine dividing each of these three curves by B, so multiplying by this reciprocal. What we end up with there is… The compute curve was linear. We divide by B, and that makes it a constant here.
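As a rough sense of scale for the cost-per-token quantity being plotted (rack rental time for one step, divided by the batch size), here is a quick worked example. The $2/GPU-hour figure is the one quoted above; the 20 ms step time and the 2,048-sequence batch are assumptions of mine for illustration only.

```python
# Rough cost-per-token arithmetic: rack rental time for one step, amortized over the batch.
GPU_HOURLY_RATE = 2.0   # $/GPU-hour (figure quoted in the transcript)
GPUS_PER_RACK = 72      # NVL72-style rack
STEP_TIME_S = 0.020     # assumed time for one decode step (illustrative)
BATCH_SIZE = 2048       # assumed number of sequences decoded together (illustrative)

rack_dollars_per_second = GPU_HOURLY_RATE * GPUS_PER_RACK / 3600   # ~ $0.04/s
cost_per_step = rack_dollars_per_second * STEP_TIME_S              # ~ $0.0008
cost_per_token = cost_per_step / BATCH_SIZE                        # one new token per sequence
print(f"about ${cost_per_token * 1e6:.2f} per million generated tokens")  # ~ $0.39/Mtok
```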
This is t compute. The KV fetch was linear, and now it becomes a constant as well. Then the weight fetch was constant, and now we've divided by B, so it becomes this hyperbola.

Again, we're going to compute the max of the sum. The sum of these two terms shifts the hyperbola up. The sum of the KV fetch and the weight fetch gives us a higher hyperbola that's like this. Then we're going to take the max with the compute here. We end up with this being the overall shape that we care about.

Again, we see some limiting behavior. The cost initially starts very high at a batch size of one. It almost goes to infinity because we've got so many weight fetches that are not amortized over a large batch size. But as we increase the batch size, the weight fetches become amortized over so many different batch elements that their cost grows very small, and eventually the compute time ends up driving the cost. So there is a limiting lower bound on cost, which is this line here.

Dwarkesh Patel: So Claude Code Slow or Codex Slow or whatever would just live on this line. It wouldn't help much because you're not able to amortize the KV values over a much bigger batch.

Reiner Pope: They're unique per batch. The compute is also unique per batch. So what is the minimum work you can do per batch after amortizing everything else away?

Dwarkesh Patel: This point where you are no longer memory bandwidth bound, practically how big a batch do you need? How big are the batches practically for frontier models?

Reiner Pope: You can just solve for that. It's not even particularly sensitive to model architecture. Let's go ahead and do that.

What we're talking about is when the memory time is equal to the compute time. That's what that question is. Because we're focused on what the batch size is—and really there's a question of when the weights are amortized over the multiplies—I'm going to focus on comparing the weight fetch time to the weight multiply time. I'm going to disregard the KV fetch term just to simplify the analysis so we can get a clean answer out. We're going to equate this portion with these two times.

Writing that out, we get N, number of total parameters, over memory bandwidth, is equal to batch size times number of active parameters divided by the compute performance. Looking over here, everything on the top are model parameters. Everything on the bottom are hardware parameters. It turns out to be nice to rearrange them such that we have the hardware parameters on one side.

This is equivalent to FLOPs over memory bandwidth being equal to batch size times number of active parameters, divided by the number of total parameters. This hardware parameter ends up being a dimensionless constant. If you look in terms of FLOPs… What are the dimensions of this? This is multiplies per second. This is bytes per second. So that's not quite dimensionless. But what you do is you say, how many FP4 multiplies per second times the fact that each FP4 is half a byte. I can actually make this end up being dimensionless. On most GPUs, this ends up being somewhere around 300.

Dwarkesh Patel: Has that ratio changed over time as we've gone from model generation to model generation, where the FLOPs keep increasing?

Reiner Pope: This is a hardware parameter. To what extent has the hardware changed? From A100 to H100 to B100, the FLOPs have increased substantially, the memory bandwidth has also increased substantially, and it has remained reasonably stable.

We can express this one as well. This is a sparsity parameter.
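In symbols, the balance condition just stated (weight-fetch time equal to weight-multiply time), with W standing for memory bandwidth and the other symbols as before; the notation is mine, not from the blackboard.

```latex
% Balance point: weight-fetch time = weight-multiply time
%   N_tot = total params, N_act = active params, B = batch size,
%   C = compute throughput, W = memory bandwidth (notation is mine)
\frac{N_{\mathrm{tot}}}{W} \;=\; \frac{B \cdot N_{\mathrm{act}}}{C}
\quad\Longleftrightarrow\quad
\frac{C}{W} \;=\; B \cdot \frac{N_{\mathrm{act}}}{N_{\mathrm{tot}}}
% Counting C in FP4 multiplies per second and W in bytes per second (one FP4
% value is half a byte), C/W becomes dimensionless and, per the discussion
% above, sits near ~300 on recent GPUs.
```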
I might even phrase this slightly differently. Let's solve for batch size in total. Moving this back over to the other side, we end up with batch size needs to be bigger than approximately 300 times sparsity. For example, in DeepSeek I activate 32 out of 256 experts, so this would be 8 for DeepSeek.

This actually gives you a ballpark which is remarkably accurate to practice. Generally, people will go a little bit larger than this. They don't really want to be exactly at the balance point because real-world efficiencies aren't as good as a roofline analysis would say. But take this and maybe double or triple it.

Dwarkesh Patel: Okay, so it's two to three thousand tokens per batch. But then if you included the KV cache, the implication would be that the optimal batch size...

Reiner Pope: Should grow larger. We solved for the equivalence between when compute time is equal to memory time. If I add in something that consumes more memory bandwidth, then I have less available for the weight loads. I need to grow the memory bandwidth more, and therefore the batch size more.

Dwarkesh Patel: This seems incredibly small. This would be less than one sequence, right?

Reiner Pope: Keep in mind that I'm talking about the number of tokens that I'm generating one more token for. It's actually 2,000 unique sequences.

Dwarkesh Patel: Got it. We're just talking about a single forward pass on these sequences. You think of the batch as the number of sequences.

Reiner Pope: That's right.

Dwarkesh Patel: If you've got a frontier model and you are actually doing inference, surely they must have more than 2,000 concurrent users. Is there any added latency from the fact that you need to have the whole batch fill up? Or if you have a reasonable amount of users, is it so unlikely that it would take you 100 milliseconds to fill up the next 2,000 slots?

Reiner Pope: The way to think about this is: when does the train depart, as a model? Let's say I've picked a batch size that I'm going to run at. By the way, this intersection point is the same intersection point here. I pick this batch size, and I know that it's going to take, for example, 20 milliseconds, which is a common place this ends up landing.

This is a timeline of what is running on the GPU. It's going to start a new batch every 20 milliseconds regardless. You can think of this as a schedule for the train. A new train departs every 20 milliseconds. Any passengers who
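Putting numbers to the batch-size ballpark above: the hardware ratio of roughly 300 and the DeepSeek-style sparsity factor of 8 are the figures quoted in the lecture, and the doubling or tripling is Reiner's rule of thumb for real-world inefficiency.

```python
# Worked version of the balance-point batch size from the lecture.
HARDWARE_RATIO = 300   # ~FP4 multiplies/s over bytes/s, dimensionless (from the lecture)
SPARSITY = 256 // 32   # activating 32 of 256 experts -> factor of 8 (DeepSeek example)

balance_batch = HARDWARE_RATIO * SPARSITY                 # ~2,400 sequences
practical_batch = (2 * balance_batch, 3 * balance_batch)  # "double or triple it"
print(balance_batch, practical_batch)                     # 2400 (4800, 7200)
```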

※ For copyright reasons, only the first 3 sections are quoted here. Please read the original for the full content.

Read the original ↗