随着开源大模型雨后春笋一样的发布,怎样根据模型的参数量来计算所需要的显存成了很多小伙伴关心的话题!我们今天就一起来了解下这个知识!
假如有一个块3090 24G的显卡,我比较关心的一定是我最大能跑多少参数的大模型?
目前模型的参数绝大多数都是float32类型, 占用4个字节。所以一个粗略的计算方法就是,每10亿个参数,占用4G显存(实际应该是10^9*4/1024/1024/1024=3.725G),不过我们用4G来计算就可以了。
比如LLaMA的参数量为7000559616,那么全精度加载这个模型参数需要的显存为:
7000559616 * 4 /1024/1024/1024 = 26.08G
如果我们用半精度的FP16/BF16来加载,这样每个参数只占2个字节,所需显存就降为一半,半精度是个不错的选择,显存少了一半,模型效果因为精度的原因会略微降低,但一般在可接受的范围之内。
除了半精度,大模型还有8位精度和4位精度,对显存的需求量分别下降到原来的1/4和1/8,我们用Qwen1.5系列的模型来举例:
Qwen1.5版的模型一共推出了7个不同的参数量,分别是0.5B、1.8B、4B、7B、14B、32B、72B
以下皆是粗略计算!
如果是全精度的话,分别需要的显卡显存是:
2G、7.2G、16G、28G、56G、128G、288G
如果是半精度的话,分别需要的显卡显存是:
1G、3.6G、8G、14G、28G、64G、144G
如果是8位精度的话,分别需要的显卡显存是:
0.5G、1.8G、4G、7G、14G、32G、72G
如果是4位精度的话,分别需要的显卡显存是:
0.25G、0.9G、2G、3.5G、7G、16G、36G
不过上面只是加载模型需要用到的显存量,模型运算时的一些临时变量也需要申请空间,比如你beam search的时候。所以真正做推理的时候记得留一些Buffer,不然就容易OOM。
参考资料:https://blog.csdn.net/weixin_44292902/article/details/133767448
原创文章,作者:朋远方,如若转载,请注明出处:https://caovan.com/zenyanggenjumoxingcanshuliangjisuantuilishixuyaodexiancun/.html