Commit 5bda989 (0 parents): 99 changed files with 6,215 additions and 0 deletions.
@@ -0,0 +1,2 @@
checkpoints/*
results/*
@@ -0,0 +1,98 @@
## Training on a custom dataset

### Data organization

A training dataset consists of two components:

1. Multi-view images of the training scenes.
2. An (inverted) extrinsic camera matrix corresponding to each image.

Please organize them in the following format:

```python
image_name = '{}sc{:04d}{}.png'.format(some_string, scene_id, some_other_string)
corresponding_extrinsic_matrix_name = '{}sc{:04d}{}_RT.txt'.format(some_string, scene_id, some_other_string)
```

where ```scene_id``` is the only factor the dataloader uses to recognize and associate the images/extrinsics of a specific scene.
It should range from ```0``` to ```n_scenes-1``` (```n_scenes``` is specified in the running script).
Specifying ```"{:04d}"``` assumes no more than 10,000 scenes in the training set.
If you have more, simply specify a larger field width.
```some_string``` and ```some_other_string``` can be anything that does not include ```"sc{:04d}"```.
An example is:
```
00000_sc0000_img0.png
00000_sc0000_img0_RT.txt
00001_sc0000_img1.png
00001_sc0000_img1_RT.txt
00002_sc0001_img0.png
00002_sc0001_img0_RT.txt
...
```
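As an illustration of how this convention can be consumed, a scene index could be built as below (this helper is a hypothetical sketch, not the repository's actual dataloader code):

```python
import os
import re
from collections import defaultdict

# Matches the 'sc{:04d}' token in the naming convention above;
# widen \d{4} if you use more than 10,000 scenes.
SCENE_RE = re.compile(r'sc(\d{4})')

def group_by_scene(root):
    """Map scene_id -> list of (image, extrinsic) file-name pairs."""
    scenes = defaultdict(list)
    for fname in sorted(os.listdir(root)):
        if not fname.endswith('.png'):
            continue
        match = SCENE_RE.search(fname)
        if match is None:
            continue
        extrinsic = fname[:-len('.png')] + '_RT.txt'
        scenes[int(match.group(1))].append((fname, extrinsic))
    return scenes
```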
An (inverted) extrinsic matrix should be a camera-to-world transform
composed of the camera pose (rotation and translation).
An extrinsic matrix should look something like this:
```
0.87944 -0.30801 0.36292 -4.49662
-0.47601 -0.56906 0.67051 -8.30771
-0.00000 -0.76243 -0.64707 8.01728
0.00000 0.00000 0.00000 1.00000
```
where the upper-left 3x3 block is the camera orientation matrix and the upper-right 3x1 vector is the camera location.
Here we assume the origin of the world frame is located at the center of the scene,
as we hard-code the locality box to be centered at the world origin
(the box is used to enforce zero density outside of it for foreground slots during early training,
to help better separate foreground and background).
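For concreteness, such a ```*_RT.txt``` file can be loaded with numpy. This is a minimal sketch assuming a whitespace-separated 4x4 matrix (the file name follows the example above):

```python
import numpy as np

# Load a camera-to-world (inverted extrinsic) matrix from an *_RT.txt file.
cam2world = np.loadtxt('00000_sc0000_img0_RT.txt')  # expected shape: (4, 4)

rotation = cam2world[:3, :3]  # upper-left 3x3 block: camera orientation
location = cam2world[:3, 3]   # upper-right 3x1 vector: camera location

# Sanity check: a valid orientation block is orthonormal.
assert cam2world.shape == (4, 4)
assert np.allclose(rotation @ rotation.T, np.eye(3), atol=1e-3)
```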
### Intrinsic parameters

Besides the extrinsics, you also need to specify the camera intrinsics in [../models/projection.py](../models/projection.py) by setting
```focal_ratio``` (the focal length divided by the number of pixels along the X/width axis and along the Y/height axis).
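As a hedged sketch of what ```focal_ratio``` means (the numbers are illustrative only; check ```projection.py``` for the actual variable layout):

```python
# Illustrative numbers only: a 640x480 image rendered with a 350-pixel focal length.
width, height = 640, 480
focal_px = 350.0

# focal_ratio as described above: focal length in pixels divided by the
# number of pixels along each axis.
focal_ratio = (focal_px / width, focal_px / height)
print(focal_ratio)  # (0.546875, 0.7291666666666666)
```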
### Script modification

When you use the training and evaluation scripts in ```/scripts```,
additionally specify ```--fixed_locality```.

Depending on the scale of your scene (measured in world-coordinate units),
you need to modify the following (an example invocation follows this list):
1. the near and far planes (```--near``` and ```--far```) for ray marching;
2. the rough scale of your scene (```--nss_scale```), which does not need to be accurate, for normalizing input coordinates;
3. the scale that roughly encloses the foreground objects (```--obj_scale```), used for computing the locality box. This should not be so small
that the box occupies too few pixels (e.g., less than 80%) when projected onto the image planes.
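For example (a hedged sketch: the script name, dataset path, and all numeric values below are placeholders, not recommended settings):

```
python train.py --dataroot /path/to/your/dataset --n_scenes 1000 --fixed_locality \
    --near 6 --far 20 --nss_scale 7 --obj_scale 3.5
```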
### Tips for training on a custom dataset

The background-aware attention brings the advantage of separating background and foreground objects,
allowing cleaner segmentation and scene manipulation.
But it also trades off some stability.
The unsupervised separation of foreground and background seems to make the module more susceptible to a sort of [rank-collapse problem](https://arxiv.org/abs/2103.03404).
There are two kinds of rank collapse I have seen during experiments:

- The foreground-background separation is good and the foreground slots are explaining objects, but the foreground slots are low-rank (i.e., some foreground slots decode the same "dim version" of several objects) or even rank-1.

- All the foreground slots decode nothing. In this case you will see zero gradients for all layers of the foreground decoder (you can check this in the visdom panel).

In the first case,
adding more training scenes or simply changing the random seed is enough.

In the second case, bad learning dynamics lead to zero foreground density, even inside the locality box, early in training, and the foreground decoder is effectively killed.
Changing the random seed might provide a better initialization and thus bypass the problem,
and you might consider tweaking hyper-parameters related to the learning-rate schedule, such as the learning rate ```--lr```, its decay speed ```--attn_decay_steps```, and the warm-up steps ```--warmup_steps```.
Adding some noise to the foreground decoder outputs could also help (see the sketch below).
Hopefully future research can shed light on this problem and solve it once and for all.
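On the noise suggestion, a hedged sketch (```raw_density``` and ```noise_std``` are hypothetical names, not the repository's actual variables):

```python
import torch

def perturb_density(raw_density, noise_std):
    """Add zero-mean Gaussian noise to the foreground decoder's raw density.

    Annealing noise_std to 0 over early training can keep the foreground
    decoder's gradients alive and help avoid the collapse described above.
    """
    if noise_std > 0:
        raw_density = raw_density + noise_std * torch.randn_like(raw_density)
    return raw_density
```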
## Generate your own dataset

We provide example generation assets and a codebase
[here](https://office365stanford-my.sharepoint.com/:u:/g/personal/koven_stanford_edu/Ec-vEV0XMxBGpWgx1y6kSkIBOiY_AelngVf2qk2zAHgb_A?e=2gIqGv)
for the object shape models from ShapeNet, and
[here](https://office365stanford-my.sharepoint.com/:f:/g/personal/koven_stanford_edu/EnJcGIJ1dadJqIRRth43dIwBVcf-5Um9yotNs2HYOqgLDA?e=6lZWCJ)
for the codebase and textures of the Room Diverse dataset.
In ``/image_generation/scripts``, run ``generate_1200shape_50bg.sh`` and then ``render_1200_shapes.sh``.
Don't forget to change the root directory in both scripts.
@@ -0,0 +1,114 @@
"""This package includes all the modules related to data loading and preprocessing | ||
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. | ||
You need to implement four functions: | ||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | ||
-- <__len__>: return the size of dataset. | ||
-- <__getitem__>: get a data point from data loader. | ||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | ||
Now you can use the dataset class by specifying flag '--dataset_mode dummy'. | ||
See our template dataset class 'template_dataset.py' for more details. | ||
""" | ||
import importlib | ||
import torch.utils.data | ||
from data.base_dataset import BaseDataset | ||
|
||
|
||
def find_dataset_using_name(dataset_name): | ||
"""Import the module "data/[dataset_name]_dataset.py". | ||
In the file, the class called DatasetNameDataset() will | ||
be instantiated. It has to be a subclass of BaseDataset, | ||
and it is case-insensitive. | ||
""" | ||
dataset_filename = "data." + dataset_name + "_dataset" | ||
datasetlib = importlib.import_module(dataset_filename) | ||
|
||
dataset = None | ||
target_dataset_name = dataset_name.replace('_', '') + 'dataset' | ||
for name, cls in datasetlib.__dict__.items(): | ||
if name.lower() == target_dataset_name.lower() \ | ||
and issubclass(cls, BaseDataset): | ||
dataset = cls | ||
|
||
if dataset is None: | ||
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) | ||
|
||
if 'collate_fn' in datasetlib.__dict__.keys(): | ||
return dataset, datasetlib.__dict__['collate_fn'] | ||
|
||
return dataset | ||
|
||
|
||
def get_option_setter(dataset_name): | ||
"""Return the static method <modify_commandline_options> of the dataset class.""" | ||
ret = find_dataset_using_name(dataset_name) | ||
dataset_class = ret[0] if type(ret) == tuple else ret | ||
return dataset_class.modify_commandline_options | ||
|
||
|
||
def create_dataset(opt): | ||
"""Create a dataset given the option. | ||
This function wraps the class CustomDatasetDataLoader. | ||
This is the main interface between this package and 'train.py'/'test.py' | ||
Example: | ||
>>> from data import create_dataset | ||
>>> dataset = create_dataset(opt) | ||
""" | ||
data_loader = CustomDatasetDataLoader(opt) | ||
dataset = data_loader.load_data() | ||
return dataset | ||
|
||
|
||
class CustomDatasetDataLoader(): | ||
"""Wrapper class of Dataset class that performs multi-threaded data loading""" | ||
|
||
def __init__(self, opt): | ||
"""Initialize this class | ||
Step 1: create a dataset instance given the name [dataset_mode] | ||
Step 2: create a multi-threaded data loader. | ||
""" | ||
self.opt = opt | ||
ret = find_dataset_using_name(opt.dataset_mode) | ||
self.collate_fn = None | ||
if type(ret) == tuple: | ||
dataset_class, self.collate_fn = ret | ||
else: | ||
dataset_class = ret | ||
self.dataset = dataset_class(opt) | ||
print("dataset [%s] was created" % type(self.dataset).__name__) | ||
self.dataloader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batch_size, | ||
shuffle=not opt.serial_batches, | ||
num_workers=int(opt.num_threads), | ||
collate_fn = self.collate_fn) | ||
|
||
def load_data(self): | ||
return self | ||
|
||
def __len__(self): | ||
"""Return the number of data in the dataset""" | ||
return min(len(self.dataset), self.opt.max_dataset_size) | ||
|
||
def __iter__(self): | ||
"""Return a batch of data""" | ||
for i, data in enumerate(self.dataloader): | ||
if i * self.opt.batch_size >= self.opt.max_dataset_size: | ||
break | ||
yield data | ||
|
||
def filter_objects(self, n): | ||
self.dataset.filter_objects(n) | ||
self.dataloader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=self.opt.batch_size, | ||
shuffle=not self.opt.serial_batches, | ||
num_workers=int(self.opt.num_threads), | ||
collate_fn = self.collate_fn) | ||
|
||
|
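To make the subclass contract in the module docstring concrete, a minimal custom dataset might look like this hedged sketch (a hypothetical 'dummy' example, not a file in this commit); it would then be selected with '--dataset_mode dummy':

```python
# data/dummy_dataset.py (hypothetical illustration)
from data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    """A minimal dataset that returns a constant record per index."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)  # stores opt and opt.dataroot
        self.paths = ['placeholder'] * 10

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return {'path': self.paths[index], 'index': index}

    def set_epoch(self, epoch):
        pass  # required by the abstract base class; unused here
```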
(17 binary files not shown.)
@@ -0,0 +1,166 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. | ||
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. | ||
""" | ||
import random | ||
import numpy as np | ||
import torch.utils.data as data | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseDataset(data.Dataset, ABC): | ||
"""This class is an abstract base class (ABC) for datasets. | ||
To create a subclass, you need to implement the following four functions: | ||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | ||
-- <__len__>: return the size of dataset. | ||
-- <__getitem__>: get a data point. | ||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | ||
""" | ||
|
||
def __init__(self, opt): | ||
"""Initialize the class; save the options in the class | ||
Parameters: | ||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions | ||
""" | ||
self.opt = opt | ||
self.root = opt.dataroot | ||
|
||
@staticmethod | ||
def modify_commandline_options(parser, is_train): | ||
"""Add new dataset-specific options, and rewrite default values for existing options. | ||
Parameters: | ||
parser -- original option parser | ||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. | ||
Returns: | ||
the modified parser. | ||
""" | ||
return parser | ||
|
||
@abstractmethod | ||
def __len__(self): | ||
"""Return the total number of images in the dataset.""" | ||
return 0 | ||
|
||
@abstractmethod | ||
def __getitem__(self, index): | ||
"""Return a data point and its metadata information. | ||
Parameters: | ||
index - - a random integer for data indexing | ||
Returns: | ||
a dictionary of data with their names. It ususally contains the data itself and its metadata information. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def set_epoch(self, epoch): | ||
"""Set the epoch for this dataset | ||
Parameters: | ||
epoch - - the epoch number | ||
""" | ||
pass | ||
|
||
|
||
def get_params(opt, size): | ||
w, h = size | ||
new_h = h | ||
new_w = w | ||
if opt.preprocess == 'resize_and_crop': | ||
new_h = new_w = opt.load_size | ||
elif opt.preprocess == 'scale_width_and_crop': | ||
new_w = opt.load_size | ||
new_h = opt.load_size * h // w | ||
|
||
x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) | ||
y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) | ||
|
||
flip = random.random() > 0.5 | ||
|
||
return {'crop_pos': (x, y), 'flip': flip} | ||
|
||
|
||
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): | ||
transform_list = [] | ||
if grayscale: | ||
transform_list.append(transforms.Grayscale(1)) | ||
if 'resize' in opt.preprocess: | ||
osize = [opt.load_size, opt.load_size] | ||
transform_list.append(transforms.Resize(osize, method)) | ||
elif 'scale_width' in opt.preprocess: | ||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) | ||
|
||
if 'crop' in opt.preprocess: | ||
if params is None: | ||
transform_list.append(transforms.RandomCrop(opt.crop_size)) | ||
else: | ||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) | ||
|
||
if opt.preprocess == 'none': | ||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) | ||
|
||
if not opt.no_flip: | ||
if params is None: | ||
transform_list.append(transforms.RandomHorizontalFlip()) | ||
elif params['flip']: | ||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) | ||
|
||
if convert: | ||
transform_list += [transforms.ToTensor()] | ||
if grayscale: | ||
transform_list += [transforms.Normalize((0.5,), (0.5,))] | ||
else: | ||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] | ||
return transforms.Compose(transform_list) | ||
|
||
|
||
def __make_power_2(img, base, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
h = int(round(oh / base) * base) | ||
w = int(round(ow / base) * base) | ||
if (h == oh) and (w == ow): | ||
return img | ||
|
||
__print_size_warning(ow, oh, w, h) | ||
return img.resize((w, h), method) | ||
|
||
|
||
def __scale_width(img, target_width, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
if (ow == target_width): | ||
return img | ||
w = target_width | ||
h = int(target_width * oh / ow) | ||
return img.resize((w, h), method) | ||
|
||
|
||
def __crop(img, pos, size): | ||
ow, oh = img.size | ||
x1, y1 = pos | ||
tw = th = size | ||
if (ow > tw or oh > th): | ||
return img.crop((x1, y1, x1 + tw, y1 + th)) | ||
return img | ||
|
||
|
||
def __flip(img, flip): | ||
if flip: | ||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
return img | ||
|
||
|
||
def __print_size_warning(ow, oh, w, h): | ||
"""Print warning information about image size(only print once)""" | ||
if not hasattr(__print_size_warning, 'has_printed'): | ||
print("The image size needs to be a multiple of 4. " | ||
"The loaded image size was (%d, %d), so it was adjusted to " | ||
"(%d, %d). This adjustment will be done to all images " | ||
"whose sizes are not multiples of 4" % (ow, oh, w, h)) | ||
__print_size_warning.has_printed = True |
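A hedged usage sketch for the transform helpers above (the option object is a hand-built stand-in for the repository's real options class, and the file name is the one from the custom-dataset example):

```python
from types import SimpleNamespace
from PIL import Image
from data.base_dataset import get_params, get_transform

# Stand-in for the real options object; only the fields read by the helpers.
opt = SimpleNamespace(preprocess='resize_and_crop', load_size=286,
                      crop_size=256, no_flip=False)

img = Image.open('00000_sc0000_img0.png').convert('RGB')
params = get_params(opt, img.size)        # fix crop position and flip once
tensor = get_transform(opt, params)(img)  # 3 x 256 x 256 tensor in [-1, 1]
```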