第三课 关于AlexNet内存爆炸问题

1.课程地址

http://zh.gluon.ai/chapter_convolutional-neural-networks/alexnet-gluon.html

2.解决

原因:来自https://discuss.gluon.ai/t/topic/3792 xiaoming

内存炸裂是因为’load_data_fashion_mnist‘函数的原因,这个函数会把fashionMNIST数据集的所有图片都先resize,然后存储到内存里面。 你这里resize = 224,然后它会把整个数据集的60000张图片一起resize,这时候数据集的数据就有60000 * 224 * 224 * 3。这个用float32存储需要30多g的内存。

办法:来自https://discuss.gluon.ai/t/topic/1258/49 xiaoming

删除了,然后把相应的功能放到class DataLoader里了。
提醒一下:原来的transform是作为gluon.data.vision.FashionMNIST的参数的。而我将transform的操作放到class DataLoader内部,而在外部只是多加了一个resize的参数。
其实我这样写少了很多功能,万一tranform的操作需要更改的话,就要去改class DataLoader的定义了。 所以如果想实现跟gluon.data.vision.FashionMNIST的参数transform一样多的功能的话,最好把整个transform函数作为class DataLoader的一个参数,然后可以在 yield里调用这个transform。
如下修改:

class DataLoader(object):
    """similiar to gluon.data.DataLoader, but might be faster.

    The main difference this data loader tries to read more exmaples each
    time. But the limits are 1) all examples in dataset have the same shape, 2)
    data transfomer needs to process multiple examples at each time
    """
    def __init__(self, dataset, batch_size, shuffle, transform):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.transform = transform

    def __iter__(self):
        data = self.dataset[:]
        X = data[0]
        y = nd.array(data[1])
        n = X.shape[0]
        if self.shuffle:
            idx = np.arange(n)
            np.random.shuffle(idx)
            X = nd.array(X.asnumpy()[idx])
            y = nd.array(y.asnumpy()[idx])

        for i in range(n//self.batch_size):
            if self.transform is not None:
                yield self.transform(X[i*self.batch_size:(i+1)*self.batch_size], 
                                     y[i*self.batch_size:(i+1)*self.batch_size])
            else:
                yield (X[i*self.batch_size:(i+1)*self.batch_size],
                       y[i*self.batch_size:(i+1)*self.batch_size])

    def __len__(self):
        return len(self.dataset)//self.batch_size

def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
    """download the fashion mnist dataest and then load into memory"""
    def transform_mnist(data, label):
        # transform a batch of examples
        if resize:
            n = data.shape[0]
            new_data = nd.zeros((n, resize, resize, data.shape[3]))
            for i in range(n):
                new_data[i] = image.imresize(data[i], resize, resize)
            data = new_data
        # change data from batch x height x weight x channel to batch x channel x height x weight
        return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
    
    mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
    mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
    train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist)
    test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist)
    return (train_data, test_data)

参考地址:
https://discuss.gluon.ai/t/topic/3792
https://discuss.gluon.ai/t/topic/1258/45

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • Spring Cloud为开发人员提供了快速构建分布式系统中一些常见模式的工具(例如配置管理,服务发现,断路器,智...
    卡卡罗2017阅读 136,044评论 19 139
  • 今天实在不知道写什么,找某人给我随便出了个题目,无奈胡诌几句好了。 看到这个题目就想起来张爱玲那句被人用到烂俗的话...
    唯见月寒日暖阅读 3,468评论 0 0
  • 我认为精彩的人生要有不断的提升,没有提升的人生不值得一过。 到了人生的某一阶段,好像生命静止了,没有热烈的追求,只...
    夏林鹿阅读 1,441评论 0 0
  • 今天室友说:“如果你再这样,你将会失去我这个宝宝。”虽然是玩笑的口吻,却让我想起了邹恒然和李琳伊。那还是个智能手机...
    熙兮晚归阅读 1,765评论 0 0

友情链接更多精彩内容