
efficient_densenet_pytorch

A PyTorch >=1.0 implementation of DenseNets, optimized to save GPU memory.

Recent updates

  1. Now works on PyTorch 1.0! It uses the checkpointing feature, which makes this code WAY more efficient!!!

Motivation

While DenseNets are fairly easy to implement in deep learning frameworks, most implementations (such as the original) tend to be memory-hungry. In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations grows quadratically with network depth. It is worth emphasizing that this is not a property inherent to DenseNets, but rather to the implementation.

This implementation uses a new strategy to reduce the memory consumption of DenseNets. We use checkpointing to compute the Batch Norm and concatenation feature maps. These intermediate feature maps are discarded during the forward pass and recomputed for the backward pass. This adds 15-20% time overhead during training, but reduces feature-map memory consumption from quadratic to linear.
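As a minimal sketch of the idea (not the repo's actual code; the module names and sizes here are illustrative), PyTorch's torch.utils.checkpoint can wrap the concatenation + Batch Norm step of a single dense layer so that its intermediates are recomputed during the backward pass:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

norm = nn.BatchNorm2d(24)                             # illustrative sizes
relu = nn.ReLU()
conv = nn.Conv2d(24, 48, kernel_size=1, bias=False)

def bn_function(*features):
    # Concatenate all previous feature maps, then BN -> ReLU -> 1x1 conv.
    # Under checkpointing, these intermediates are freed after the forward
    # pass and recomputed on the fly during the backward pass.
    return conv(relu(norm(torch.cat(features, 1))))

f1 = torch.randn(4, 12, 32, 32, requires_grad=True)  # two previous feature maps
f2 = torch.randn(4, 12, 32, 32, requires_grad=True)
out = checkpoint(bn_function, f1, f2)                 # memory-efficient path
out.sum().backward()                                  # reruns bn_function internally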

This implementation is inspired by the technical report of Pleiss et al. (see the Reference below), which outlines a strategy for efficient DenseNets via memory sharing.

Requirements

  • PyTorch >=1.0.0
  • CUDA

Usage

In your existing project: there is a single model file in the models folder; copy it into your project.

If you care about speed and memory is not a concern, pass efficient=False to the DenseNet constructor. Otherwise, pass efficient=True.

Options:

  • All options are described in the docstrings of the model files
  • The depth is controlled by the block_config option
  • efficient=True uses the memory-efficient version
  • If you want to use the model for ImageNet, set small_inputs=False. For CIFAR or SVHN, set small_inputs=True (see the example below).
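Putting these together, a minimal construction sketch (assuming the model file exposes a DenseNet class with the options named above; check the docstrings for the exact signature):

from models import DenseNet

model = DenseNet(
    growth_rate=12,             # features added per layer
    block_config=(16, 16, 16),  # three dense blocks of 16 layers each
    small_inputs=True,          # True for CIFAR/SVHN, False for ImageNet
    efficient=True,             # use the memory-efficient (checkpointed) version
)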

Running the demo:

The only extra package you need to install is python-fire:

pip install fire
  • Single GPU:
CUDA_VISIBLE_DEVICES=0 python demo.py --efficient True --data <path_to_folder_with_cifar10> --save <path_to_save_dir>
  • Multiple GPUs:
CUDA_VISIBLE_DEVICES=0,1,2 python demo.py --efficient True --data <path_to_folder_with_cifar10> --save <path_to_save_dir>

Options (a combined example follows this list):

  • --depth (int) - depth of the network (number of convolution layers) (default 40)
  • --growth_rate (int) - number of features added per DenseNet layer (default 12)
  • --n_epochs (int) - number of epochs for training (default 300)
  • --batch_size (int) - size of minibatch (default 256)
  • --seed (int) - manually set the random seed (default None)
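For example, combining several of these options into one memory-efficient training run (paths are placeholders, as above):

CUDA_VISIBLE_DEVICES=0 python demo.py --efficient True --data <path_to_folder_with_cifar10> --save <path_to_save_dir> --depth 100 --growth_rate 12 --batch_size 64 --seed 0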

Performance

A comparison of the two implementations (each is a DenseNet-BC with 100 layers, batch size 64, tested on an NVIDIA Pascal Titan X):

Implementation          Memory consumption (GB/GPU)   Speed (sec/minibatch)
Naive                   2.863                          0.165
Efficient               1.605                          0.207
Efficient (multi-GPU)   0.985                          -

Other efficient implementations

Reference

@article{pleiss2017memory,
  title={Memory-Efficient Implementation of DenseNets},
  author={Pleiss, Geoff and Chen, Danlu and Huang, Gao and Li, Tongcheng and van der Maaten, Laurens and Weinberger, Kilian Q},
  journal={arXiv preprint arXiv:1707.06990},
  year={2017}
}
