Commit 5bda989: init
Red-Fairy committed Feb 3, 2024 (0 parents)
Showing 99 changed files with 6,215 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
checkpoints/*
results/*
98 changes: 98 additions & 0 deletions data/README.md
## Training on a custom dataset

### Data organization

A training dataset consists of two components:

1. Multi-view images of the training scenes.
2. An (inverted) extrinsic camera matrix corresponding to each image.

Please organize them in the following format:

```python
image_name = '{}sc{:04d}{}.png'.format(some_string, scene_id, some_other_string)
corresponding_extrinsic_matrix_name = '{}sc{:04d}{}_RT.txt'.format(some_string, scene_id, some_other_string)
```

where ```scene_id``` is the only field the dataloader uses to recognize and associate the images/extrinsics of a specific scene.
It should range from ```0``` to ```n_scenes-1``` (```n_scenes``` is specified in the running script).
Specifying ```"{:04d}"``` assumes no more than 10,000 scenes in the training set;
if you have more, simply specify a larger field width.
```some_string``` and ```some_other_string``` can be anything that does not include ```"sc{:04d}"```.

An example is:
```
00000_sc0000_img0.png
00000_sc0000_img0_RT.txt
00001_sc0000_img1.png
00001_sc0000_img1_RT.txt
00002_sc0001_img0.png
00002_sc0001_img0_RT.txt
...
```
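
The ```scene_id``` can be recovered from such filenames with a simple regular expression. Below is a minimal sketch (the helper name ```parse_scene_id``` is ours, for illustration only):

```python
import re

def parse_scene_id(filename):
    """Extract the scene id from a filename following the "sc{:04d}" convention."""
    # \d+ also covers field widths larger than 4 if you have more than 10,000 scenes.
    match = re.search(r'sc(\d+)', filename)
    if match is None:
        raise ValueError('no "sc" block found in %s' % filename)
    return int(match.group(1))

assert parse_scene_id('00002_sc0001_img0.png') == 1
```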

An (inverted) extrinsic matrix should be a camera-to-world transform
composed of the camera pose (rotation and translation).
An extrinsic matrix should look something like this:
```
0.87944 -0.30801 0.36292 -4.49662
-0.47601 -0.56906 0.67051 -8.30771
-0.00000 -0.76243 -0.64707 8.01728
0.00000 0.00000 0.00000 1.00000
```
where the upper-left 3x3 block is the camera orientation matrix and the upper-right vector is the camera location.
Here we assume the origin of the world frame is located at the center of the scene,
as we hard-code the locality box to be centered at the world origin
(the box is used to enforce zero density outside it for foreground slots during early training,
which helps better separate foreground from background).
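
For reference, here is a minimal sketch of reading one of these ```*_RT.txt``` files (assuming ```numpy``` and the whitespace-separated 4x4 layout shown above):

```python
import numpy as np

# Load a 4x4 camera-to-world matrix from a whitespace-separated text file.
c2w = np.loadtxt('00000_sc0000_img0_RT.txt')  # shape (4, 4)

R = c2w[:3, :3]  # upper-left block: camera orientation
t = c2w[:3, 3]   # upper-right vector: camera location in the world frame

# With the world origin at the scene center, ||t|| is roughly the camera's
# distance from the scene center.
distance = np.linalg.norm(t)
```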

### Intrinsic parameters

Besides extrinsics, you also need to specify the camera intrinsics in [../models/projection.py](../models/projection.py) by setting
its ```focal_ratio``` (expressed as the focal length divided by the number of pixels along the X/width and Y/height axes).
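
For example, with a (hypothetical) focal length of 350 pixels and 128x128 renders, the ratios would be:

```python
# Illustrative numbers only; use the actual values from your renderer.
focal_px = 350.0          # focal length expressed in pixels
width, height = 128, 128  # image resolution
focal_ratio = (focal_px / width, focal_px / height)  # (2.734375, 2.734375)
```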

### Script modification

When you use the training and evaluation scripts in ```/scripts```,
additionally specify ```--fixed_locality```.

According to the scale of your scene (measured in world-coordinate units),
you need to modify
1. the near and far planes (```--near``` and ```--far```) for ray marching,
2. the rough scale of your scene (```--nss_scale```), which does not need to be accurate, for normalizing input coordinates, and
3. the scale that roughly encloses the foreground objects (```--obj_scale```) for computing the locality box. This should not be so small
that the box occupies too few pixels (e.g., <80%) when projected onto the image planes.

### Tips for training on a custom dataset

The background-aware attention brings the advantage of separating background and foreground objects,
allowing cleaner segmentation and scene manipulation.
But it also trades off some stability.
The unsupervised separation of foreground and background seems to make the module more susceptible to a form of [rank-collapse problem](https://arxiv.org/abs/2103.03404).
There are two kinds of rank collapse I have seen during experiments.

- The foreground-background separation is good and the foreground slots are explaining objects, but the foreground slots are low-rank (i.e., some foreground slots decode the same "dim version" of several objects) or even rank-1.

- All the foreground slots decode nothing. In this case you will see zero gradients for all layers of the foreground decoder (visible in the visdom panel).

In the first case,
adding more training scenes or simply changing the random seed is usually enough.

In the second case, bad learning dynamics lead to zero foreground density even inside the locality box early in training, and the foreground decoder is effectively killed.
Changing the random seed might provide a better initialization and thus bypass the problem,
and you might also consider tweaking the hyper-parameters of the learning-rate schedule, such as the learning rate ```--lr```, its decay speed ```--attn_decay_steps```, and the warm-up steps ```--warmup_steps```.
Adding some noise to the foreground decoder outputs could also help.
Hopefully future research can shed light on this problem and solve it once and for all.


## Generate your own dataset

We provide example generation assets and code:
[here](https://office365stanford-my.sharepoint.com/:u:/g/personal/koven_stanford_edu/Ec-vEV0XMxBGpWgx1y6kSkIBOiY_AelngVf2qk2zAHgb_A?e=2gIqGv)
for the object shape models from ShapeNet, and
[here](https://office365stanford-my.sharepoint.com/:f:/g/personal/koven_stanford_edu/EnJcGIJ1dadJqIRRth43dIwBVcf-5Um9yotNs2HYOqgLDA?e=6lZWCJ)
for the codebase and textures of the Room Diverse dataset.

In ``/image_generation/scripts``, run ``generate_1200shape_50bg.sh`` and then ``render_1200_shapes.sh``.
Don't forget to change the root directory in both scripts.
114 changes: 114 additions & 0 deletions data/__init__.py
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".
    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
                and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with a class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    # If the dataset module also defines a custom collate_fn, return it alongside the class.
    if 'collate_fn' in datasetlib.__dict__.keys():
        return dataset, datasetlib.__dict__['collate_fn']

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    ret = find_dataset_using_name(dataset_name)
    # find_dataset_using_name returns either the class or a (class, collate_fn) tuple.
    dataset_class = ret[0] if type(ret) == tuple else ret
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.
    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'.
    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class
        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        ret = find_dataset_using_name(opt.dataset_mode)
        self.collate_fn = None
        if type(ret) == tuple:
            dataset_class, self.collate_fn = ret
        else:
            dataset_class = ret
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            collate_fn=self.collate_fn)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data points in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches of data, stopping at max_dataset_size"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data

    def filter_objects(self, n):
        """Filter the underlying dataset, then rebuild the data loader."""
        self.dataset.filter_objects(n)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.opt.batch_size,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.num_threads),
            collate_fn=self.collate_fn)


166 changes: 166 additions & 0 deletions data/base_dataset.py
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.
    To create a subclass, you need to implement the following functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of the dataset.
    -- <__getitem__>: get a data point.
    -- <set_epoch>: set the current epoch (abstract in this class).
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.
        Parameters:
            parser -- the original option parser
            is_train (bool) -- whether it is the training phase or the test phase. You can use this flag to add training-specific or test-specific options.
        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.
        Parameters:
            index -- a random integer for data indexing
        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def set_epoch(self, epoch):
        """Set the epoch for this dataset
        Parameters:
            epoch -- the epoch number
        """
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w

    # Sample a random crop position and a random horizontal-flip decision.
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5

    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
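

# Example usage (a minimal sketch; `opt` carries the fields referenced above and
# `img` is assumed to be a PIL.Image):
#
#     params = get_params(opt, img.size)             # fix crop/flip decisions
#     transform = get_transform(opt, params=params)  # build the deterministic pipeline
#     tensor = transform(img)                        # normalized torch.Tensor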


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img

    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only printed once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True