解决pytorch在训练时由于设置了验证集导致out of memory(同样可用于测试时减少显存占用)

问题描述:


最近一直在使用pytorch, 由于深度学习的网络往往需要设置验证集来验证模型是否稳定.

我一直再做一个关于医学影像分割的课题,为了查看自己的模型是否稳定,于是设置了验证集.

但是在运行的过程中,当程序执行到 validatioon时,显存立即上升,我可怜的显卡只有8GB显存,瞬间爆炸.

怎么办呢?实验得做呀.于是找了不少方法,比如设置各个网络变量requires_grad=False,但是并不管用,显存依然爆炸.

后来百度了一番,终于解决了显存爆炸的问题.

解决方案:


假设训练程序是这样的:

for train_data, train_label in  train_dataloader:

    do 

           trainning

then

for valid_data,valid_label in valid_dataloader:

    do 

            validtion

当程序执行到validation时,显存忽然上升,几乎是之前的两倍.


只需要这样改:

for train_data, train_label in train_dataloader:

        do

            trainning


then

with torch.no_grad():

    for valid_data,valid_label in valid_dataloader:

            do

                validtion

当程序执行到validation时,显存将不再上升.问题得到解决.真的是非常简单.

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容