PyTorch显存优化技巧:从基础到进阶

PyTorch显存优化技巧:从基础到进阶

一、小模块API参数inplace设置为True(省一点点)

比如:Relu()有一个默认参数inplace,默认设置为False,当设置为True时,计算时的得到的新值不会占用新的空间而是直接覆盖原来的值,进而可以节省一点点内存。

二、Apex半精度计算(省一半左右)

安装方式

git clone https://github.com/NVIDIA/apex

cd apex

python3 setup.py install

原理:一款Nvidia开发的基于PyTorch的混合精度训练加速神器,可以用短短三行代码(导包、初始化、反向传播)就能实现不同程度的混合精度加速,训练时间直接缩小一半。

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”

with amp.scale_loss(loss, optimizer) as scaled_loss:

scaled_loss.backward()

其中只有一个opt_level需要用户自行配置(# 这里是“欧x”,不是“零x”):

O0:纯FP32训练,可以作为accuracy的baseline;O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;

注:apex不能和contiguous-params一起用,否则无法实现模型参数优化。(因为这俩模型参数指针要打架)

三、删无用变量并及时清空显存垃圾(省将近一半)

这个的用法比较简单,就是先del变量,再cuda垃圾回收。

del answers_types_emb_all, answers_types_len

torch.cuda.empty_cache()

但这个的使用方式,不能在整个代码中只使用一次torch.cuda.empty_cache()——否则整体的最大显存不会变化。

正确的使用方法是:应该在整体代码的多个位置,分别del变量后,对应位置之后即使用torch.cuda.empty_cache()。

这样整体最大显存,就会下降(del的变量整体显存之和 - max的一坨del变量的显存)——分批清空显存,才能真正地用torch.cuda.empty_cache()降低显存。

四、重计算方法(省一半以上)

用pytorch的checkpoint断点保存机制,将连续的几个函数计算,分为若干段,保存了再计算——用保存加载的时间,换显存计算的空间

原来一个函数的计算方式:

sequence_enc = self.encoder(sequence_emb, sequence_mask, output_all_encoded_layers=False)[-1]

用checkpoint断点保存的方式:

from torch.utils.checkpoint import checkpoint

sequence_enc = checkpoint(self.encoder, sequence_emb, sequence_mask, torch.tensor([0]))[-1]

值得一提的是,checkpoint函数,要求输入的函数(如self.encoder),返回值必须为torch.tensor类型,如果原来函数(self.encoder)的返回值是一个数组

return all_encoder_layers

用torch.stack改造一下,就可以用checkpoint了

return torch.stack(all_encoder_layers, dim=0)

注:checkpoint后没有梯度记忆,解决方法参考Pytorch使用GPU进行训练注意事项 | 文艺数学君 (mathpretty.com)的方法二

五、效果

楼主同时用了上述四项节省显存的技术,将原始显存大于11GB的程序,降低到了只有2GB-3GB的水平

上述技术节省显存的程度(由多到少):

重计算方法Apex半精度计算删无用变量并及时清空显存垃圾小模块API参数inplace设置为True

上述技术实现的复杂程度(由易到难):

小模块API参数inplace设置为TrueApex半精度计算删无用变量并及时清空显存垃圾重计算方法

PS 你如果有24GB以上显存的显卡的话,那就当我什么都没说。

当然还有些不是方法的方法:降低Batch的大小,减小模型的超参数,不需要计算显存时用torch.no_grad():等。

此外,代码变量之间的不当依赖,会导致显存持续增长,具体情况详见显存持续缓慢增长的究极原因 - 知乎 (zhihu.com)

相关星际资讯

美服王者荣耀叫什么
365bet足球平台

美服王者荣耀叫什么

🕒 07-13 👁️ 6141
云顶之弈挂机有没有惩罚?LOL云顶之弈挂机会被封号吗?
没磕没碰,为什么身上总有淤青?这几种情况要小心