1.告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用?
2.Lora相比于全参训练节省的显存是哪一部分?Qlora相比Lora呢?
3.混合精度训练的具体流程是怎么样的?
这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。
1.数据精度
想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少bit。 我们都知道:
1 byte = 8 bits
1 KB = 1,024 bytes
1 MB = 1,024 KB
1 GB = 1,024 MB
由此可以明白,一个含有1G参数的模型,如果每一个参数都是32bit(4byte),那么直接加载模型就会占用4x1G的显存。
1.1常见的几种精度类型
个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:
各种精度的数据结构
可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa) 三部分组成。 符号位都是1位(0表示正,1表示负),指数位影响浮点数范围,小数位影响精度。 其中TF32并不是有32bit,只有19bit不要记错了。BF16指的是Brain Float 16,由Google Brain团队提出。
1.2 具体计算例子
我硕士话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据: 我以BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:
- 题目:
随机生成的BF16精度数据
- 先给出具体计算公式:
(−1)Sign⋅2Exponent−127⋅(1+Mantissa128)
- 然后step by step地分析(不是,怎么还对自己使用上Cot了)
符号位Sign = 1,代表是负数
指数位Exponent = 17,中间一坨是 2−110
小数位Mantissa = 3,后面那一坨是 1+3128
- 最终结果
三个部分乘起来就是最终结果 -8.004646331359449e-34
- 注意事项
中间唯一需要注意的地方就是指数位是的全0和全1状态是特殊情况,不能用公式,如果想要深入了解可以看这个博客: 彻底搞懂float16与float32的计算方式-CSDN博客 如果感兴趣想更加深入了解如何从FP32转换为BF16的,可以看这个博主的讲解: 从一次面试搞懂 FP16、BF16、TF32、FP32
2.全参训练和推理的显存分析
OK了我们知道了数据精度对应存储的方式和大小, 相当于我们了解了工厂里不同规格的机器零件,但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。
2.1混合精度训练
2.1.1 原理介绍
顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《 MIXED PRECISION TRAINING 》这篇论文里将FP16和FP32混合,优化器用的是Adam,如下图所示:
MIXED PRECISION TRAINING论文里的训练流程图
按照训练运行的逻辑来讲:
Step1:优化器会先备份一份FP32精度的模型权重,初始化好FP32精度的一阶和二阶动量(用于更新权重)。
Step2:开辟一块新的存储空间,将FP32精度的模型权重转换为FP16精度的模型权重。
Step3:运行forward和backward,产生的梯度和激活值都用FP16精度存储。
Step4:优化器利用FP16的梯度和FP32精度的一阶和二阶动量去更新备份的FP32的模型权重。
Step5:重复Step2到Step4训练,直到模型收敛。
我们可以看到训练过程中显存主要被用在四个模块上:
- 模型权重本身(FP32+FP16)
- 梯度(FP16)
- 优化器(FP32)
- 激活值(FP16)
2.1.2 三个小问题
写到这里,我就有3个小问题,第一个问题,为什么不全都用FP16,那不是计算更快、内存更少?
根据我们第一章的知识,我们可以知道FP16精度的范围比FP32窄了很多,这就会产生数据溢出和舍入误差两个问题(想深入了解的,请看全网最全-混合精度训练原理),这会导致梯度消失无法训练,所以我们不能全都用FP16,还需要FP32来进行精度保证。看到这里你也许会想到可以用BF16代替,是的,这也是为什么如今很多训练都是BF16的原因,至少BF16不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。
第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个FP32精度的模型副本,这样子显存不会更大吗?
答案是不会 , 激活值和batch_size以及seq_length相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。(这里还可以考虑梯度检查点的优化方法,能更进一步优化激活值的显存,感兴趣可以看看这个大模型高效训练基础知识:梯度检查点(Gradient Checkpointing))。
第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?
应该很多人都能猜到:
- 静态:优化器状态、模型参数
- 动态:激活值、梯度值
也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。 如果想要深度探索,指路[LLM]大模型显存计算公式与优化
动态监控显存图
2.1.3 来个测试吧!
写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。 对于llama3.1 8B模型,FP32和BF16混合精度训练,用的是AdamW优化器,请问模型训练时占用显存大概为多少?
解:
模型参数:16(BF16) + 32(PF32)= 48G
梯度参数:16(BF16)= 16G
优化器参数:32(PF32) + 32(PF32)= 64G
不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G
2.2 推理与KV Cache
2.2.1 原理理解
推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的KV cache也会占用显存。KV cache与之前讲的如何减少显存不一样,KV cache的目的是减少延迟,也就是为了推理的速度牺牲显存。
具体KV cache是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了(如果还不明白可以去看大模型推理加速:看图学KV Cache),记住一点,我们推理就是在不断重复地做”生成下一个token“的任务,生成当前token 仅仅与当前的QKV和之前所有KV有关,那么我们就可以去维护这个KV并不断更新。
KV Cache动态实现
顺便回答一个很多小白经常会问的问题,为什么没有Q Cache呢?
因为生成当前的token只依赖当前的Q,那为什么生成当前的token只依赖当前的Q呢,因为Self-Attention的公式决定的,S代表Softmax激活函数:
𝑆(𝑄𝐾⊤)𝑉 = [𝑆1(𝑞1⋅𝑘1)0⋯0𝑆2(𝑞2⋅𝑘1)𝑆2(𝑞2⋅𝑘2)⋯0⋮⋮⋱⋮𝑆𝑇(𝑞𝑇⋅𝑘1)𝑆𝑇(𝑞𝑇⋅𝑘2)⋯𝑆𝑇(𝑞𝑇⋅𝑘𝑇)] [𝑣1𝑣2⋮𝑣𝑇]
= [𝑆1(𝑞1⋅𝑘1)𝑣1𝑆2(𝑞2⋅𝑘1)𝑣1+𝑆2(𝑞2⋅𝑘2)𝑣2⋮𝑆𝑇(𝑞𝑇⋅𝑘1)𝑣1+𝑆𝑇(𝑞𝑇⋅𝑘2)𝑣2+⋯+𝑆𝑇(𝑞𝑇⋅𝑘𝑇)𝑣𝑇]
我们可以看到,在序列t的位置,也就是第t行,只跟 𝑄𝑡 有关系,也就是说,Attention的计算公式就决定了我们不需要保存每一步的Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的Q。
2.2.2 计算KV Cache显存
如何计算KV Cache的显存是我这篇文章想要关心的事情,先给出公式:
𝑀𝑒𝑚𝑜𝑟𝑦=𝑏𝑎𝑡𝑐ℎ𝑠𝑖𝑧𝑒×𝑠𝑒𝑞𝑙𝑒𝑛𝑔𝑡ℎ×ℎ𝑖𝑑𝑑𝑒𝑛𝑠𝑖𝑧𝑒×𝑙𝑎𝑦𝑒𝑟𝑠×2×2
前面的四个参数相乘应该很好理解,就是KV对应在模型每一层的所有隐藏向量的总和,第一个2指的是KV两部分,第二个2指的是半精度对应的字节数。
举个栗子,对于llama7B,hiddensize = 4096,seqlength = 2048 , batchsize = 64,layers = 32 计算得到
𝑀𝑒𝑚𝑜𝑟𝑦=64×2048×4096×32×2×2≈68𝐺 ,
可以看到,KV Cache在大批量长句子的情况下,显存占用率也是很大的。
68G看着是相对模型本身很大,但这是在batch很大的情况下,在单batch下,KV Cache就仅占有 1G左右的显存了,就仅仅占用模型参数一半的显存。
2.2.3 MQA和GQA
什么,你觉得KV Cache用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA和GQA就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。
三种KV处理方式
其实方法不难理解,看这张图一目了然,关键词就是“共享多头KV”,很朴素的删除模型冗余结构的思路。最左侧就是最基础的MHA多头自注意力,中间的GQA就是保留几组KV头,右侧MQA就是只保留1组KV头,目前用的比较多的是GQA,降低显存提速的同时也不会太过于影响性能。如果没看懂的小伙伴可以去大模型推理加速:KV Cache 和 GQA详细看看,这里就不展开讲了,我想讲的是具体显存的变化。
上一小节我们知道MHA的KV Cache占用显存的计算公式是 𝑀𝑒𝑚𝑜𝑟𝑦=𝑏𝑎𝑡𝑐ℎ𝑠𝑖𝑧𝑒×𝑠𝑒𝑞𝑙𝑒𝑛𝑔𝑡ℎ×ℎ𝑖𝑑𝑑𝑒𝑛𝑠𝑖𝑧𝑒×𝑙𝑎𝑦𝑒𝑟𝑠×2×2
这里 ℎ𝑖𝑑𝑑𝑒𝑛𝑠𝑖𝑧𝑒 拆为 ℎ𝑖𝑑𝑑𝑒𝑛𝑠𝑖𝑧𝑒=ℎ𝑒𝑎𝑑×ℎ𝑒𝑎𝑑_𝑑𝑖𝑚
原公式改写为
𝑀𝑒𝑚𝑜𝑟𝑦=𝑏𝑎𝑡𝑐ℎ𝑠𝑖𝑧𝑒×𝑠𝑒𝑞𝑙𝑒𝑛𝑔𝑡ℎ×ℎ𝑒𝑎𝑑×ℎ𝑒𝑎𝑑_𝑑𝑖𝑚×𝑙𝑎𝑦𝑒𝑟𝑠×2×2
那么不难理解,MQA占用显存公式则是
𝑀𝑒𝑚𝑜𝑟𝑦=𝑏𝑎𝑡𝑐ℎ𝑠𝑖𝑧𝑒×𝑠𝑒𝑞𝑙𝑒𝑛𝑔𝑡ℎ×1×ℎ𝑒𝑎𝑑𝑑𝑖𝑚×𝑙𝑎𝑦𝑒𝑟𝑠×2×2
GQA占用显存公式则是
𝑀𝑒𝑚𝑜𝑟𝑦=𝑏𝑎𝑡𝑐ℎ𝑠𝑖𝑧𝑒×𝑠𝑒𝑞𝑙𝑒𝑛𝑔𝑡ℎ×𝑔𝑟𝑜𝑢𝑝×ℎ𝑒𝑎𝑑𝑑𝑖𝑚×𝑙𝑎𝑦𝑒𝑟𝑠×2×2
公式中改变的就是共享的头数。
有一个小细节,可以重头开始训练MQA 和 GQA的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。
3.Lora和Qlora显存分析
上面两章详细对全参微调训练和推理进行了显存分析,聪明的小伙伴就发现了一个问题,现在都用PEFT(高效参数微调)了,谁有那么多资源全参训练啊推理阶段也是要量化的,这样又该怎么进行显存分析呢。那么我们这一章就来解决这个问题,我相信完全理解前两章的小伙伴理解起来会非常轻松,所谓的显存分析,只要知道了具体的流程和数据精度,那么分析的方法都是类似的。OK,我们将会在这一章里详细分析目前前业界最火的Lora和Qlora方法的显存占用情况,中间也会涉及到相关的原理知识,冲!
3.1 Lora
能看到这里的人,我想对于Lora的原理应该都很了解了,就浅浅提一下,如下图所示,就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量从 𝑑∗𝑑 降为 2∗𝑑∗𝑟 。
Lora原理图
有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以BF16半精度模型Adamw优化器训练为例子,lora部分的参数精度也是BF16,并且设1字节模型参数对应的显存大小 Φ 。
- 首先是模型权重本身的权重,这个肯定是要加载原始模型和lora旁路模型的,因为lora部分占比小于2个数量级,所以显存分析的时候忽略不计,显存占用 2Φ 。
- 然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理,也就是说优化器只包含Lora模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存 0Φ 。
- 其实容易搞错混淆的部分就是梯度的显存了,我看了不少的博客文章,有说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要Lora部分的梯度显存,搞得我头很大。那么究竟正确答案是哪一种呢,这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为 0Φ 。想深入探究的可以去大模型高效微调-LoRA原理详解和训练过程深入分析。
总的来说,不考虑激活值的情况下,Lora微调训练的显存占用只有 2Φ ,一个7B的模型Lora训练只需要占用显存大约14G左右。验证一下,我们来看Llama Factory里给出训练任务的显存预估表格:
Llama Factory的表格
可以看到7B模型的Lora训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。
3.2 QLora
上面Llama Factory的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是QLora,继Lora之后也是在业界落地非常广泛通用的一种大模型PEFT方法。QLora,也叫做量化Lora,顾名思义,也就是进一步压缩模型的精度,然后用Lora训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。我同样也不会太过深入地去介绍这个中细节,如想深入了解可以去看论文或者其他博客(指路QLoRA、GPTQ:模型量化概述,QLoRA(Quantized LoRA)详解),我主要是想按照显存占用的思路去分析Qlora,理解思路永远比死的知识点更加重要。
3.2.1 Qlora的整体思路
Qlora来自于《 QLORA: Efficient Finetuning of Quantized LLMs 》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是Lora。 很多不了解的人看到量化lora这个名字就以为是对Lora部分的参数进行量化,因为他们认为毕竟只有Lora部分的参数参与了训练,但理解了上面一节的小伙伴就明白实际并不是这样,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora优化的正是Lora里显存占大头的模型参数本身。
那么Qlora就是把原始模型参数从16bit压缩到4bit,然后更新这个4bit参数吗?非也非也,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。QLora的方法就是,加载并且量化16bit的模型原始参数为4bit作为存储参数,但是在具体需要计算的时候,将该部分的4bit参数反量化为16bit作为计算参数。 也就是说,QLora实际上我们训练计算里用到的所有数据的精度都是和Lora一样的,只是加载的模型是4bit,会进行一个反量化到16bit的方法,用完即释放。前面说到的都是模型原始参数本身,不包括lora部分的参数,Lora部分的参数不需要量化,一直都是16bit。 看到这里机智的你应该也想到了,这比Lora多了一个量化反量化的操作,那训练时间是不是会更长,没错一般来讲Qlora训练会比Lora多用30%左右的时间。
3.2.2 Qlora的技术细节
基本的思路讲完了,那么其中包含了哪些具体的实现细节呢? Qlora主要包括三个创新点,这里我只简单提及,应付面试足够的程度,如果想要详细了解可以去看论文:
- NF4量化:常见的量化分布都是基于参数是均匀分布的假设,而这个方法基于参数是正态分布的假设,这样使得量化精度大大提升。
- 双重量化:对于第一次量化后得到的用于计算反量化时的锚点参数,我们对这个锚点参数进行量化,可以进一步降低显存。
- 优化器分页:为了防止OOM,可以在GPU显存紧张的时候利用CPU内存进行加载参数。
3.2.3 显存分析
想必已经理解QLora运行思路的小伙伴,应该可以很轻松的分析出Qlora占用显存的部分了吧,这就是理清楚思路的好处。没错,Qlora占用的显存主要就是4Bit量化后的模型本身也就是 0.5Φ ,这里没有考虑少量的Lora部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。
最后我们用一个表格来总结所有之前我们提到的显存分析:
部分显存对应精度(训练) | 全参微调(全FP16) | 全参微调(BF16混合精度) | Lora | QLora |
---|---|---|---|---|
主干模型(模型存储/计算参数) | FP16/FP16 | BF16/BF16 | BF16/BF16 | NF4/BF16 |
主干模型(梯度) | FP16 | BF16 | Null | Null |
主干模型(adamw优化器) | 2 x FP16 | 3 x FP32 | Null | Null |
LoRA部分(可忽略不计) | Null | Null | BF16 | BF16 |
总和(大约) | 8Byte | 16Byte | 2Byte | 0.5Byte |
如何学习AI大模型 ?
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。【保证100%免费】🆓
CSDN粉丝独家福利
这份完整版的 AI 大模型学习资料已经上传CSDN,朋友们如果需要可以扫描下方二维码&点击下方CSDN官方认证链接免费领取 【保证100%免费】
读者福利: 👉👉CSDN大礼包:《最新AI大模型学习资源包》免费分享 👈👈
对于0基础小白入门:
如果你是零基础小白,想快速入门大模型是可以考虑的。
一方面是学习时间相对较短,学习内容更全面更集中。
二方面是可以根据这些资料规划好学习计划和方向。
👉1.大模型入门学习思维导图👈
要学习一门新的技术,作为新手一定要先学习成长路线图,方向不对,努力白费。
对于从来没有接触过AI大模型的同学,我们帮你准备了详细的学习成长路线图&学习规划。可以说是最科学最系统的学习路线,大家跟着这个大的方向学习准没问题。(全套教程文末领取哈)
👉2.AGI大模型配套视频👈
很多朋友都不喜欢晦涩的文字,我也为大家准备了视频教程,每个章节都是当前板块的精华浓缩。
👉3.大模型实际应用报告合集👈
这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。(全套教程文末领取哈)
👉4.大模型落地应用案例PPT👈
光学理论是没用的,要学会跟着一起做,要动手实操,才能将自己的所学运用到实际当中去,这时候可以搞点实战案例来学习。(全套教程文末领取哈)
👉5.大模型经典学习电子书👈
随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。(全套教程文末领取哈)
👉6.大模型面试题&答案👈
截至目前大模型已经超过200个,在大模型纵横的时代,不仅大模型技术越来越卷,就连大模型相关的岗位和面试也开始越来越卷了。为了让大家更容易上车大模型算法赛道,我总结了大模型常考的面试题。(全套教程文末领取哈)
👉学会后的收获:👈
• 基于大模型全栈工程实现(前端、后端、产品经理、设计、数据分析等),通过这门课可获得不同能力;
• 能够利用大模型解决相关实际项目需求: 大数据时代,越来越多的企业和机构需要处理海量数据,利用大模型技术可以更好地处理这些数据,提高数据分析和决策的准确性。因此,掌握大模型应用开发技能,可以让程序员更好地应对实际项目需求;
• 基于大模型和企业数据AI应用开发,实现大模型理论、掌握GPU算力、硬件、LangChain开发框架和项目实战技能, 学会Fine-tuning垂直训练大模型(数据准备、数据蒸馏、大模型部署)一站式掌握;
• 能够完成时下热门大模型垂直领域模型训练能力,提高程序员的编码能力: 大模型应用开发需要掌握机器学习算法、深度学习
CSDN粉丝独家福利
这份完整版的 AI 大模型学习资料已经上传CSDN,朋友们如果需要可以扫描下方二维码&点击下方CSDN官方认证链接免费领取 【保证100%免费】
读者福利: 👉👉CSDN大礼包:《最新AI大模型学习资源包》免费分享 👈👈