3 changes: 3 additions & 0 deletions .gitignore
@@ -6,6 +6,8 @@ cpp/mnist/build
 cpp/dcgan/build
 dcgan/*.png
 dcgan/*.pth
+distributed/dcgan/*.png
+distributed/dcgan/*.pth
 snli/.data
 snli/.vector_cache
 snli/results
@@ -23,3 +25,4 @@ docs/venv
 # development
 .vscode
 **/.DS_Store
+
4 changes: 2 additions & 2 deletions dcgan/main.py
@@ -152,7 +152,7 @@ def __init__(self, ngpu):
 
     def forward(self, input):
 
-        if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
+        if (input.is_cuda) and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
         else:
             output = self.main(input)
@@ -192,7 +192,7 @@ def __init__(self, ngpu):
         )
 
     def forward(self, input):
-        if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
+        if (input.is_cuda) and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
         else:
             output = self.main(input)
63 changes: 63 additions & 0 deletions distributed/dcgan/README.md
@@ -0,0 +1,63 @@
# Deep Convolutional Generative Adversarial Networks

This example implements the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434).

The implementation closely follows the Torch implementation [dcgan.torch](https://github.com/soumith/dcgan.torch).

After every 100 training iterations, the files `real_samples.png` and `fake_samples.png` are written to disk
with samples from the generative model.

After every epoch, the models are saved to `netG_epoch_%d.pth` and `netD_epoch_%d.pth`.
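The per-epoch checkpointing above can be sketched with the standard `torch.save`/`load_state_dict` pattern. This is a minimal illustration, not the example's actual code; `netG`, `outf`, and `epoch` here are hypothetical stand-ins for the script's variables:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the script's generator, output folder, and epoch counter.
netG = nn.Sequential(nn.ConvTranspose2d(100, 64, 4, 1, 0), nn.Tanh())
outf = "."
epoch = 0

# Save only the weights, matching the netG_epoch_%d.pth naming scheme.
torch.save(netG.state_dict(), "%s/netG_epoch_%d.pth" % (outf, epoch))

# To resume training, rebuild the same architecture and load the weights
# (this is what the --netG / --netD flags are for).
netG_resumed = nn.Sequential(nn.ConvTranspose2d(100, 64, 4, 1, 0), nn.Tanh())
netG_resumed.load_state_dict(torch.load("%s/netG_epoch_%d.pth" % (outf, epoch)))
```

Saving the state dict rather than the whole module keeps the checkpoint portable across code changes.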

## Downloading the dataset

You can download the LSUN dataset by cloning [this repo](https://github.com/fyu/lsun) and running

```bash
python download.py -c bedroom
```

## Installation

```bash
pip install -r requirements.txt
```

## Running Examples

You can run the examples using `torchrun` to launch distributed training:

```bash
torchrun --nnodes=1 --nproc_per_node=4 main.py --dataset fake
```

For more details, check the `run_examples.sh` script.
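Each process that `torchrun` launches is expected to join a process group using the environment variables `torchrun` sets (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`). The following is a minimal sketch of that pattern, not the example's actual `main.py`; it uses the CPU-friendly `gloo` backend and single-process defaults purely for illustration:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets these for each worker; the defaults below let the
# sketch also run standalone as a single process.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Join the default process group, reading rank/world size from the env.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Wrap the model so gradients are all-reduced across workers.
# (A CPU module is wrapped without device_ids; a toy Linear stands
# in for the DCGAN networks.)
model = DDP(torch.nn.Linear(10, 2))
out = model(torch.randn(4, 10))

dist.destroy_process_group()
```

With `--nproc_per_node=4`, four such processes start, each seeing a different `RANK`, and DDP synchronizes their gradients on every backward pass.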

## Usage

```
usage: main.py [-h] --dataset DATASET [--dataroot DATAROOT] [--workers WORKERS]
[--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ] [--niter NITER]
[--lr LR] [--beta1 BETA1] [--dry-run] [--ngf NGF] [--ndf NDF] [--netG NETG]
[--netD NETD] [--outf OUTF] [--manualSeed MANUALSEED] [--classes CLASSES]

options:
-h, --help show this help message and exit
  --dataset DATASET    cifar10 | lsun | mnist | imagenet | folder | lfw | fake
--dataroot DATAROOT path to dataset
--workers WORKERS number of data loading workers
--batchSize BATCHSIZE input batch size
--imageSize IMAGESIZE the height / width of the input image to network
--nz NZ size of the latent z vector
--niter NITER number of epochs to train for
--lr LR learning rate, default=0.0002
--beta1 BETA1 beta1 for adam. default=0.5
--dry-run check a single training cycle works
--ngf NGF
--ndf NDF
--netG NETG path to netG (to continue training)
--netD NETD path to netD (to continue training)
--outf OUTF folder to output images and model checkpoints
--manualSeed MANUALSEED manual seed
--classes CLASSES comma separated list of classes for the lsun data set
```