intra fid
PeterouZh committed Jun 23, 2020
1 parent 7a64ce4 commit 7969ace
Showing 6 changed files with 189 additions and 93 deletions.
26 changes: 23 additions & 3 deletions README.md
@@ -52,7 +52,8 @@ which should be put into the directories mentioned above.

## Evaluate the models reported in the paper

Download the pre-trained models [onedrive](https://sjtueducn-my.sharepoint.com/:f:/g/personal/zhoupengcv_sjtu_edu_cn/EsokPqpwPMhPi8IjPh8WQBoBQF9S1iunCj-EdpawvjyyHQ?e=hAphaG), and put them into *./datasets/nas_cgan/models*.
Download the pre-trained models [onedrive](https://sjtueducn-my.sharepoint.com/:f:/g/personal/zhoupengcv_sjtu_edu_cn/EsokPqpwPMhPi8IjPh8WQBoBQF9S1iunCj-EdpawvjyyHQ?e=owFvIe),
and put them into *./datasets/nas_cgan/models*.
### FID and IS
Eval NAS-cGAN on CIFAR10.
```bash
@@ -78,11 +79,30 @@ python exp/nas_cgan/scripts/train_net.py \

### intra FIDs

Pre-calculate FID statistics on CIFAR10.
Pre-calculate intra FID statistics for each class on CIFAR10.
Or you can use our pre-calculated files [onedrive](https://sjtueducn-my.sharepoint.com/:f:/g/personal/zhoupengcv_sjtu_edu_cn/EhWbm-z9lLJDpcZ5KuqmfO0Bd5ak80J5QBT_G3y6zkYdEw?e=PR3VEF),
and put these files in *./datasets/nas_cgan/tf_fid_stat/cifar10_train_per_class_32*.
```bash

export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:/usr/local/cudnn-10.0-v7.6.5.32/lib64:$LD_LIBRARY_PATH
export CUDA_VISIBLE_DEVICES=0
export PYTHONPATH=.:./exp
python exp/nas_cgan/scripts/train_net.py \
--config exp/nas_cgan/configs/calculate_fid_stat_per_class_CIFAR10.yaml \
--command calculate_fid_stat_per_class_CIFAR10 \
--outdir results/calculate_fid_stat_per_class_CIFAR10
```
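
Each pre-calculated file simply stores the two FID moments (Inception feature mean and covariance) for the real images of one class. Below is a minimal sketch of what producing such a file could look like, assuming the pool3 features have already been extracted; the file naming is an assumption, not necessarily the layout the repository uses.
```python
import numpy as np

def save_class_fid_stats(features, class_idx, out_dir):
    """features: (N, 2048) array of Inception pool3 activations for one class's real images."""
    mu = np.mean(features, axis=0)          # per-feature mean
    sigma = np.cov(features, rowvar=False)  # feature covariance matrix
    np.savez(f"{out_dir}/class_{class_idx}.npz", mu=mu, sigma=sigma)
```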

Eval intra FIDs for NAS-cGAN.
```bash
export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:/usr/local/cudnn-10.0-v7.6.5.32/lib64:$LD_LIBRARY_PATH
export CUDA_VISIBLE_DEVICES=0
export PYTHONPATH=.:./exp
python exp/nas_cgan/scripts/train_net.py \
--config exp/nas_cgan/configs/eval_intra_FID_NAS_cGAN_CIFAR10.yaml \
--command eval_intra_FID_NAS_cGAN_CIFAR10 \
--outdir results/eval_intra_FID_NAS_cGAN_CIFAR10
```
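
Intra FID is the ordinary Fréchet distance applied class by class: samples generated for class *c* are compared against the real statistics pre-computed for class *c*, and the per-class distances are reported (optionally alongside the total FID, cf. `eval_total_FID` in the config below). A standalone sketch of the distance and the per-class loop, assuming the Inception feature moments are already computed; this is not the repository's exact implementation.
```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """FID between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # regularize if the covariance product is (near-)singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))

def intra_fids(fake_stats, real_stats):
    """fake_stats / real_stats: per-class lists of (mu, sigma) pairs, one entry per class."""
    return [frechet_distance(mf, sf, mr, sr)
            for (mf, sf), (mr, sr) in zip(fake_stats, real_stats)]
```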

## Acknowledgement

1. https://github.com/facebookresearch/detectron2
20 changes: 20 additions & 0 deletions exp/nas_cgan/configs/calculate_fid_stat_per_class_CIFAR10.yaml
@@ -0,0 +1,20 @@
calculate_fid_stat_per_class_CIFAR10:
args:
num_gpus: 1
start:
name: compute_fid_stats_per_class
imagenet_root_dir: datasets/cifar10
dataset_name: cifar10_train_per_class
IMS_PER_BATCH: 50
img_size: 32
NUM_WORKERS: 0
dataset_mapper:
name: CIFAR10DatasetMapper
img_size: kwargs['img_size']
GAN_metric:
names:
# - PyTorchFIDISScore
- TFFIDISScore
torch_fid_stat: datasets/nas_cgan/pytorch_fid_stat/{dataset_name}_{img_size}/
tf_fid_stat: datasets/nas_cgan/tf_fid_stat/{dataset_name}_{img_size}
tf_inception_model_dir: datasets/nas_cgan/tf_inception_model
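
The `{dataset_name}` and `{img_size}` placeholders in the two stat paths are filled in at runtime via ordinary `str.format` calls (visible in the `compute_inception_moment.py` hunk further down), so for this config the TF statistics land in the directory referenced by the README:
```python
tf_fid_stat = "datasets/nas_cgan/tf_fid_stat/{dataset_name}_{img_size}"
print(tf_fid_stat.format(dataset_name="cifar10_train_per_class", img_size=32))
# datasets/nas_cgan/tf_fid_stat/cifar10_train_per_class_32
```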
137 changes: 137 additions & 0 deletions exp/nas_cgan/configs/eval_intra_FID_NAS_cGAN_CIFAR10.yaml
@@ -0,0 +1,137 @@
eval_intra_FID_NAS_cGAN_CIFAR10:
args:
num_gpus: 1
resume: false
start:
name: do_train
dataset_name: cifar10_train
IMS_PER_BATCH: 32
max_epoch: 1000
ASPECT_RATIO_GROUPING: false
NUM_WORKERS: 0
checkpoint_period: iter_every_epoch
dataset_mapper:
name: CIFAR10DatasetMapper
img_size: 32
run_func: compute_intra_FID
trainer:
name: TrainerRetrainConditional
iter_every_epoch: kwargs['iter_every_epoch']
n_classes: 10
img_size: kwargs['img_size']
train_bs: kwargs['train_bs']
eval_torch_every_itr: 10*iter_every_epoch
train_controller_every_iter: 50
eval_tf_every_itr: 10*iter_every_epoch
fixed_arc_file: exp/nas_cgan/configs/cond_arcs_cifar10.log
fixed_epoch: 0
use_ema: true
load_model: false
ckpt_dir: results/CondDenseGAN_CIFAR10_v5/retrain_20200401-21_33_53_332/detectron2
ckpt_epoch: 500
ckpt_iter_every_epoch: 781
GAN_metric:
names:
- TFFIDISScore
torch_fid_stat: datasets/nas_cgan/pytorch_fid_stat/fid_stats_pytorch_cifar10_train_32.npz
tf_fid_stat: datasets/nas_cgan/tf_fid_stat/fid_stats_tf_cifar10_train_32.npz
tf_inception_model_dir: datasets/nas_cgan/tf_inception_model
num_inception_images: 50000
generator:
name: DenseGeneratorCBN_v2
ch: 512
linear_ch: 128
bottom_width: 4
dim_z: 256
init_type: orth
embedding_dim: 128
cfg_upsample:
name: UpSample
num_cells: 3
cfg_cell:
name: DenseBlock
n_nodes: 2
cfg_mix_layer:
name: MixedLayer
cfg_out_bn:
name: NoNorm
cfg_ops:
None:
name: D2None
Conv2d_3x3:
name: ActConv2d
cfg_act:
name: ReLU
cfg_conv:
name: Conv2d
kernel_size: 3
padding: 1
StyleV2Conv:
name: StyleV2Conv
cfg_modconv:
name: ModulatedConv2d
kernel_size: 3
style_dim: 192
optimizer:
name: Adam
lr: 0.0001
betas:
- 0.0
- 0.9
eps: 1.0e-08
controller:
name: FairController
num_layers: kwargs['num_layers']
num_branches: kwargs['num_branches']
optimizer:
name: NoneOptim
discriminator:
name: AutoGANCIFAR10ADiscriminatorCProj
ch: 256
d_spectral_norm: true
init_type: orth
cfg_act:
name: ReLU
optimizer:
name: Adam
lr: 0.0001
betas:
- 0.0
- 0.9
eps: 1.0e-08
noise:
z_train:
name: Normal
loc: 0
scale: 1
sample_shape: kwargs['sample_shape']
num_ops: kwargs['num_ops']
y_train:
name: CategoricalUniform
n_classes: kwargs['n_classes']
sample_shape: kwargs['sample_shape']
num_ops: kwargs['num_ops']
z_test:
name: Normal
loc: 0
scale: 1
sample_shape: kwargs['sample_shape']
y_test:
name: CategoricalUniform
n_classes: kwargs['n_classes']
sample_shape: kwargs['sample_shape']
GAN_model:
name: HingeLossCond
dummy: false
n_critic: 5
log_every: 50
dataset:
dataset_mapper: CIFAR10DatasetMapper
img_size: 32
compute_intra_FID:
ckpt_path: datasets/nas_cgan/models/nas_cgan_cifar10.pth
registed_name: cifar10_train_per_class
fid_stats_dir: datasets/nas_cgan/tf_fid_stat/cifar10_train_per_class_32
num_inception_images: 5000
intra_FID_file: intra_FID.npz
eval_total_FID: true
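
Note that `ckpt_path` here points directly at the released `nas_cgan_cifar10.pth`; with this commit it takes precedence over the older `ckpt_dir`/`ckpt_epoch`/`ckpt_iter_every_epoch` options, which become optional (see the `trainer_nasgan.py` hunk below). A minimal sketch of that selection logic follows, with the fallback filename pattern being an assumption:
```python
def resolve_checkpoint(ckpt_path=None, ckpt_dir="", ckpt_epoch=0, ckpt_iter_every_epoch=0):
    """Prefer an explicit checkpoint path; otherwise derive one from dir/epoch, as before."""
    if ckpt_path:
        return ckpt_path
    # hypothetical dir/epoch-based lookup standing in for self._get_ckpt_path(...)
    return f"{ckpt_dir}/model_{ckpt_epoch * ckpt_iter_every_epoch:07d}.pth"
```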
92 changes: 7 additions & 85 deletions exp/nas_cgan/models/trainer_nasgan.py
@@ -598,44 +598,6 @@ def train_func(self, data, iteration, pbar):
self.evaluate_model(classes_arcs=classes_arcs, iteration=iteration)
comm.synchronize()

def derive_arcs(self):
from numpy import array, array_equal, allclose
fixed_arc_file = get_attr_kwargs(self.cfg.derive_arcs, 'fixed_arc_file')
num_rows = get_attr_kwargs(self.cfg.derive_arcs, 'num_rows', default=self.n_classes)

counter = 0
arc_list = []
try:
n_lines = rawgencount(fixed_arc_file)
print(n_lines)
with open(fixed_arc_file) as f:
while True:
if counter >= n_lines:
break
print(f"[{counter}/{n_lines}]", flush=True)

iteration = int(f.readline().strip(': \n'))
counter += 1
sample_arc = []
for _ in range(num_rows):
class_arc = f.readline().strip('[\n ]')
counter += 1
sample_arc.append(np.fromstring(class_arc, dtype=int, sep=' '))
sample_arc = np.array(sample_arc)
if array_eq_in_list(sample_arc, arc_list):
continue
arc_list.append(sample_arc)

classes_arcs = torch.from_numpy(sample_arc)
if len(classes_arcs) == 1:
classes_arcs = classes_arcs.repeat(self.n_classes, 1)
self.evaluate_model(classes_arcs=classes_arcs, iteration=iteration*self.iter_every_epoch)

except:
import traceback
print(traceback.format_exc())
print('End.')
pass

@TRAINER_REGISTRY.register()
class TrainerCondController(TrainerSupernetCondController):
@@ -826,18 +788,19 @@ def compute_intra_FID(self, ):

cfg = self.cfg.compute_intra_FID

ckpt_dir = get_attr_kwargs(cfg, 'ckpt_dir')
ckpt_epoch = get_attr_kwargs(cfg, 'ckpt_epoch')
ckpt_iter_every_epoch = get_attr_kwargs(cfg, 'ckpt_iter_every_epoch')
ckpt_path = get_attr_kwargs(cfg, 'ckpt_path', default=None)
ckpt_dir = get_attr_kwargs(cfg, 'ckpt_dir', default='')
ckpt_epoch = get_attr_kwargs(cfg, 'ckpt_epoch', default=0)
ckpt_iter_every_epoch = get_attr_kwargs(cfg, 'ckpt_iter_every_epoch', default=0)
registed_name = get_attr_kwargs(cfg, 'registed_name')
fid_stats_dir = get_attr_kwargs(cfg, 'fid_stats_dir')
num_inception_images = get_attr_kwargs(cfg, 'num_inception_images')
intra_FID_file = get_attr_kwargs(cfg, 'intra_FID_file')
eval_total_FID = get_attr_kwargs(cfg, 'eval_total_FID', default=True)


ckpt_path = self._get_ckpt_path(ckpt_dir=ckpt_dir, ckpt_epoch=ckpt_epoch,
iter_every_epoch=ckpt_iter_every_epoch)
if not ckpt_path:
ckpt_path = self._get_ckpt_path(ckpt_dir=ckpt_dir, ckpt_epoch=ckpt_epoch,
iter_every_epoch=ckpt_iter_every_epoch)
self._load_model(ckpt_path)

classes_arcs = self.arcs
@@ -1090,44 +1053,3 @@ def after_resume(self):
pass


@TRAINER_REGISTRY.register()
class TrainerRetrainEnsembleConditional(TrainerRetrainConditional):

def __init__(self, cfg, **kwargs):
super().__init__(cfg=cfg, **kwargs)

self.num_ensemble_arcs = get_attr_kwargs(cfg.trainer, 'num_ensemble_arcs', **kwargs)
pass

def train_func(self, data, iteration, pbar):
images, labels = self.preprocess_image(data)
images = images.tensor

bs = len(images)

if self.num_ensemble_arcs == self.n_classes:
batched_arcs = self.arcs.repeat((bs, 1))
elif self.num_ensemble_arcs == 1:
# idx = iteration % self.n_classes
idx = torch.randint(self.n_classes, (1,))
batched_arcs = self.arcs[idx].repeat((bs, 1))
else:
arc_idx = [labels]
for i in range(self.num_ensemble_arcs - 1):
idx = torch.randperm(labels.nelement())
arc_idx.append(labels[idx])
arc_idx = torch.stack(arc_idx, dim=1).view(-1)
batched_arcs = self.arcs[arc_idx]

images = torch.repeat_interleave(images, self.num_ensemble_arcs, 0)
labels = torch.repeat_interleave(labels, self.num_ensemble_arcs, 0)

self.gan_model(images=images, labels=labels, z=self.z_train, iteration=iteration, batched_arcs=batched_arcs,
ema=self.ema, max_iter=self.max_iter)

# Just for monitoring the training processing
classes_arcs = self.arcs
self.evaluate_model(classes_arcs=classes_arcs, iteration=iteration)
comm.synchronize()


4 changes: 0 additions & 4 deletions exp/nas_cgan/scripts/compute_inception_moment.py
@@ -126,10 +126,6 @@ def compute_fid_stats_per_class(cfg, args, myargs):
from template_lib.d2.data.build_cifar10_per_class import find_classes
elif dataset_name.startswith('cifar100_train_per_class'):
from template_lib.d2.data.build_cifar100_per_class import find_classes
elif dataset_name.startswith('imagenet'):
# register all class of ImageNet for dataloader
from template_lib.d2.data.build_ImageNet_per_class import ImageNetDatasetPerClassMapper
from template_lib.d2.data.BigGAN import find_classes

torch_fid_stat = torch_fid_stat.format(dataset_name=dataset_name, img_size=img_size)
tf_fid_stat = tf_fid_stat.format(dataset_name=dataset_name, img_size=img_size)
3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,4 +2,5 @@ pycocotools
easydict
scipy
imageio
tensorboardX
tensorboardX
opencv-python
