大模型显存占用分析:训练与推理
作为大模型时代的研究者,尽管我们做的是一些下游任务,我认为也有必要了解一些更加底层的基础知识。
基于此出发点,本文将简单讨论如下主题:
- LLM训练时的显存占用分析。
- KV Cache
- LLM推理时的显存占用分析。
LLM训练时的显存占用分析
在分析显存占用前,需要明确数据的存储精度:float32占用4字节,float16占用2字节,int8占用1字节,int4占用半个字节。
大模型训练时的显存占用主要由以下几个部分组成:模型参数、优化器、激活值、梯度值。除此之外,还有激活值、训练数据、缓存和驱动。
我们先分析主要的三大头。
模型参数:以1B为例,1B的模型参数量是1e9,如果使用全精度存储的fp32,其每个参数的存储为4字节,所以总共的显存占用量:1B \times 4byte = 10^9 \times 4 / 1024 / 1024/1024 = 4GB
梯度值:和模型参数同一个量级。
优化器:优化器是训练中显存占用的大部头。首先需要明确的是,优化器对梯度的存储是使用fp32的。大模型的训练往往采用AdamW,对于一个Adam优化器来说,其主要存储了三种数值:梯度指数平滑值、梯度平方指数平滑值、参数值。他们分别约等于模型参数的显存占用量。这里就是:4GB+4GB+4GB=12GB
激活值:前向传播的时候会通过激活函数,产生激活值,这个值反向传播的时候也会用到。激活值与模型的层数有关,计算比较复杂。
例如,对于Llama7B来说,有如下参数已确定(seq_len=2048, head=32, batch_size=1, hidden_size=4096, transformer_layer=32),那么需要的存储量为:4byte \times seq_len * batch \times hidden \times layer \times (34+ seq \times head / hidden) = 50GB
上面的公式中,34+s*a/h
是一个经过分析后得出的结论,具体可以见这篇文章:Reducing Activation Recomputation in Large Transformer Models
此外还有一些其他的显存占用:
输入输出:相比于模型参数那些非常小,可以忽略不计。
缓存和驱动:实际上还需要CUDA Kernels一般占用1G的显存。
案例分析:llama13b训练
下面我们以llama-13b为例,分析训练时的显存占用。假设采用混合精度训练,前向用FP16进行,batch=1,seq=1024,hidden=5120
混合精度训练指代的是单精度 float和半精度 float16 混合。这篇文章对混合精度训练进行了比较深入的讲解浅谈混合精度训练。
- 输入输出:2 \times b \times seq \times hidden \times 2byte / 1024 / 1024 = 20M
- 模型参数:以fp16存储,前面我们知道1B用fp16大致为2GB,所以13B为26GB。
- 优化器:使用fp32训练(原因要参考混合精度训练的资料)。梯度指数平滑值:13 \times 4 = 52GB,梯度平方指数平滑值也是52GB,模型参数52GB,一共156GB。
- 激活值:参考上面的公式,batch=1时大致为14.5GB
- 梯度值:26GB
合计一共有222.5GB,恐怖如斯。
LoRA节约了哪部分显存
LoRA方法通过减少优化器的参数量实现了节约显存。
在LoRA方法中,模型的主体部分是冻结的,只有低秩矩阵部分的参数需要调整,此时优化器仅存储lora矩阵的参数。
而在LoRA矩阵的秩r远小于Transformer层的输入和输出维度大小 dmodel 时,所优化的参数量仅为原始模型的不到1%。
还是以刚才的13b举例子。主要的显存占用就是原来模型的参数26B,至于Lora新增的参数一定比这小很多,因为优化器只需要对新增参数做优化,所以整体下来显存省了很多。
推理中的KV Cache
GPT系列模型做推理的时候是一个自回归的过程,即token是“一个一个吐出来的过程”。
我们回顾一下attention的计算过程,每个X需要通过Linear层产生Q、K、V,然后通过矩阵运算来完成注意力计算。
由于推理过程中第i+1轮总是在第i轮的基础上新吐出了一个token,我们可以把前面几个token的k、v向量缓存起来,就只需要计算这一轮新的token的k、v向量。
例如,在输出“我|爱|你|中国”的时候,我们在计算到“你”这个token的时候会计算其kv向量,那么在计算“中国”的时候就可以直接用上之前计算的kv向量,不用再算一遍。
引入KVCache后,虽然每次推理过程输入的token数不断在增加,但由于每一轮只需读取历史Cache处理当前的token,从而推理过程的flops基本上是恒定的。
LLM推理的显存开销和时间计算
大模型的推理可以分为两个阶段:预填充、解码。
预填充:模型开始生成文本之前,先计算输入的embedding并送入模型。预填充的速度记为TTFT(time to first token)
解码阶段:模型根据输入上下文和所有先前的token逐步生成后续token。解码速度影响每个输出的token时间TPOT(time to output token)
显存占用主要是来源于模型参数和KV Cache。
下面我们采用llama-7B来进行分析,并约定这些符号表示:s:序列长度,b:批量大小,h:隐藏维度,L:Transformer层数,N:模型参数,GPU Flops速率(A100大概是312e12 flops /s),GPU HBM速率(A100大概是1.5TB/s)。
那么根据上面对1B模型的占用分析,假设使用半精度,此时的模型参数占用大约为14GB。
计算量估计:
假设使用float16,计算量 = 预填充计算 + 解码填充。
预填充计算 = 2Nbs=2 \times 7B \times 1024 = 14336B,解码计算=2Nb=2 \times 7B = 14B(解释:2表示每个参数大概会算两遍,一个加法一个乘法。之所以预填充乘了s,解码没有,是因为解码阶段一次生成1个token)
内存评估:预填充显存 = 解码显存 = 2N
时间计算:
TTFT = \frac{预填充计算量}{GPU flops速率} + \frac{预填充显存}{GPU HBM速率} = \frac{14381B}{312e12 \ flops/s}+\frac{14GB}{1.515TB/s} = 46ms + 9.3ms = 55.3ms
TPOT = \frac{解码计算量}{GPU flops速率} + \frac{解码内存}{GPU HBM速率}=0.045ms+9.3ms = 9.3ms
显存计算:=模型参数+KV缓存=14GB+(2 \times h \times L \times batch \times seq) \times 2byte = 14GB + (2 \times 4096 \times 32 \times 1 \times 1024) \times 2 = 14GB+512M
参考
- 感谢你赐予我前进的力量