作为大模型时代的研究者,尽管我们做的是一些下游任务,我认为也有必要了解一些更加底层的基础知识。

基于此出发点,本文将简单讨论如下主题:

  1. LLM训练时的显存占用分析。
  2. KV Cache
  3. LLM推理时的显存占用分析。

LLM训练时的显存占用分析

在分析显存占用前,需要明确数据的存储精度:float32占用4字节,float16占用2字节,int8占用1字节,int4占用半个字节。

大模型训练时的显存占用主要由以下几个部分组成:模型参数优化器激活值梯度值。除此之外,还有激活值、训练数据、缓存和驱动。

我们先分析主要的三大头。

模型参数:以1B为例,1B的模型参数量是1e9,如果使用全精度存储的fp32,其每个参数的存储为4字节,所以总共的显存占用量:1B \times 4byte = 10^9 \times 4 / 1024 / 1024/1024 = 4GB

梯度值:和模型参数同一个量级。

优化器:优化器是训练中显存占用的大部头。首先需要明确的是,优化器对梯度的存储是使用fp32的。大模型的训练往往采用AdamW,对于一个Adam优化器来说,其主要存储了三种数值:梯度指数平滑值梯度平方指数平滑值参数值。他们分别约等于模型参数的显存占用量。这里就是:4GB+4GB+4GB=12GB

激活值:前向传播的时候会通过激活函数,产生激活值,这个值反向传播的时候也会用到。激活值与模型的层数有关,计算比较复杂。

例如,对于Llama7B来说,有如下参数已确定(seq_len=2048, head=32, batch_size=1, hidden_size=4096, transformer_layer=32),那么需要的存储量为:4byte \times seq_len * batch \times hidden \times layer \times (34+ seq \times head / hidden) = 50GB

上面的公式中,34+s*a/h是一个经过分析后得出的结论,具体可以见这篇文章:Reducing Activation Recomputation in Large Transformer Models

此外还有一些其他的显存占用:

输入输出:相比于模型参数那些非常小,可以忽略不计。

缓存和驱动:实际上还需要CUDA Kernels一般占用1G的显存。

案例分析:llama13b训练

下面我们以llama-13b为例,分析训练时的显存占用。假设采用混合精度训练,前向用FP16进行,batch=1,seq=1024,hidden=5120

混合精度训练指代的是单精度 float和半精度 float16 混合。这篇文章对混合精度训练进行了比较深入的讲解浅谈混合精度训练

  1. 输入输出2 \times b \times seq \times hidden \times 2byte / 1024 / 1024 = 20M
  2. 模型参数:以fp16存储,前面我们知道1B用fp16大致为2GB,所以13B为26GB。
  3. 优化器:使用fp32训练(原因要参考混合精度训练的资料)。梯度指数平滑值:13 \times 4 = 52GB,梯度平方指数平滑值也是52GB,模型参数52GB,一共156GB。
  4. 激活值:参考上面的公式,batch=1时大致为14.5GB
  5. 梯度值:26GB

合计一共有222.5GB,恐怖如斯。

LoRA节约了哪部分显存

LoRA方法通过减少优化器的参数量实现了节约显存。

在LoRA方法中,模型的主体部分是冻结的,只有低秩矩阵部分的参数需要调整,此时优化器仅存储lora矩阵的参数。

而在LoRA矩阵的秩r远小于Transformer层的输入和输出维度大小 dmodel 时,所优化的参数量仅为原始模型的不到1%。

还是以刚才的13b举例子。主要的显存占用就是原来模型的参数26B,至于Lora新增的参数一定比这小很多,因为优化器只需要对新增参数做优化,所以整体下来显存省了很多。

推理中的KV Cache

GPT系列模型做推理的时候是一个自回归的过程,即token是“一个一个吐出来的过程”。

我们回顾一下attention的计算过程,每个X需要通过Linear层产生Q、K、V,然后通过矩阵运算来完成注意力计算。

由于推理过程中第i+1轮总是在第i轮的基础上新吐出了一个token,我们可以把前面几个token的k、v向量缓存起来,就只需要计算这一轮新的token的k、v向量。

例如,在输出“我|爱|你|中国”的时候,我们在计算到“你”这个token的时候会计算其kv向量,那么在计算“中国”的时候就可以直接用上之前计算的kv向量,不用再算一遍。

引入KVCache后,虽然每次推理过程输入的token数不断在增加,但由于每一轮只需读取历史Cache处理当前的token,从而推理过程的flops基本上是恒定的。

LLM推理的显存开销和时间计算

大模型的推理可以分为两个阶段:预填充、解码。

预填充:模型开始生成文本之前,先计算输入的embedding并送入模型。预填充的速度记为TTFT(time to first token)

解码阶段:模型根据输入上下文和所有先前的token逐步生成后续token。解码速度影响每个输出的token时间TPOT(time to output token)

显存占用主要是来源于模型参数KV Cache

下面我们采用llama-7B来进行分析,并约定这些符号表示:s:序列长度,b:批量大小,h:隐藏维度,L:Transformer层数,N:模型参数,GPU Flops速率(A100大概是312e12 flops /s),GPU HBM速率(A100大概是1.5TB/s)。

那么根据上面对1B模型的占用分析,假设使用半精度,此时的模型参数占用大约为14GB。

计算量估计

假设使用float16,计算量 = 预填充计算 + 解码填充。

预填充计算 = 2Nbs=2 \times 7B \times 1024 = 14336B解码计算=2Nb=2 \times 7B = 14B(解释:2表示每个参数大概会算两遍,一个加法一个乘法。之所以预填充乘了s,解码没有,是因为解码阶段一次生成1个token)

内存评估:预填充显存 = 解码显存 = 2N

时间计算

TTFT = \frac{预填充计算量}{GPU flops速率} + \frac{预填充显存}{GPU HBM速率} = \frac{14381B}{312e12 \ flops/s}+\frac{14GB}{1.515TB/s} = 46ms + 9.3ms = 55.3ms

TPOT = \frac{解码计算量}{GPU flops速率} + \frac{解码内存}{GPU HBM速率}=0.045ms+9.3ms = 9.3ms

显存计算=模型参数+KV缓存=14GB+(2 \times h \times L \times batch \times seq) \times 2byte = 14GB + (2 \times 4096 \times 32 \times 1 \times 1024) \times 2 = 14GB+512M

参考