Skip to content

Commit

Permalink
Added struct2depth model
Browse files Browse the repository at this point in the history
  • Loading branch information
aneliaangelova committed Nov 19, 2018
1 parent 4d4eb85 commit 811fa20
Show file tree
Hide file tree
Showing 15 changed files with 3,864 additions and 0 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
/research/slim/ @sguada @nathansilberman
/research/steve/ @buckman-google
/research/street/ @theraysmith
/research/struct2depth/ @aneliaangelova
/research/swivel/ @waterson
/research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick
/research/tcn/ @coreylynch @sermanet
Expand Down
1 change: 1 addition & 0 deletions research/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ request.
- [slim](slim): image classification models in TF-Slim.
- [street](street): identify the name of a street (in France) from an image
using a Deep RNN.
- [struct2depth](struct2depth): unsupervised learning of depth and ego-motion.
- [swivel](swivel): the Swivel algorithm for generating word embeddings.
- [syntaxnet](syntaxnet): neural models of natural language syntax.
- [tcn](tcn): Self-supervised representation learning from multi-view video.
Expand Down
1 change: 1 addition & 0 deletions research/struct2depth/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package(default_visibility = ["//visibility:public"])
147 changes: 147 additions & 0 deletions research/struct2depth/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# struct2depth

This a method for unsupervised learning of depth and egomotion from monocular video, achieving new state-of-the-art results on both tasks by explicitly modeling 3D object motion, performing on-line refinement and improving quality for moving objects by novel loss formulations. It will appear in the following paper:

**V. Casser, S. Pirk, R. Mahjourian, A. Angelova, Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos, AAAI Conference on Artificial Intelligence, 2019**
https://arxiv.org/pdf/1811.06152.pdf

This code is implemented and supported by Vincent Casser (git username: VincentCa) and Anelia Angelova (git username: AneliaAngelova). Please contact [email protected] for questions.

Project website: https://sites.google.com/view/struct2depth.

## Quick start: Running training

Before running training, run gen_data_* script for the respective dataset in order to generate the data in the appropriate format for KITTI or Cityscapes. It is assumed that motion masks are already generated and stored as images.
Models are trained from an Imagenet pretrained model.

```shell

ckpt_dir="your/checkpoint/folder"
data_dir="KITTI_SEQ2_LR/" # Set for KITTI
data_dir="CITYSCAPES_SEQ2_LR/" # Set for Cityscapes
imagenet_ckpt="resnet_pretrained/model.ckpt"

python train.py \
--logtostderr \
--checkpoint_dir $ckpt_dir \
--data_dir $data_dir \
--architecture resnet \
--imagenet_ckpt $imagenet_ckpt \
--imagenet_norm true \
--joint_encoder false
```



## Running depth/egomotion inference on an image folder

KITTI is trained on the raw image data (resized to 416 x 128), but inputs are standardized before feeding them, and Cityscapes images are cropped using the following cropping parameters: (192, 1856, 256, 768). If using a different crop, it is likely that additional training is necessary. Therefore, please follow the inference example shown below when using one of the models. The right choice might depend on a variety of factors. For example, if a checkpoint should be used for odometry, be aware that for improved odometry on motion models, using segmentation masks could be advantageous (setting *use_masks=true* for inference). On the other hand, all models can be used for single-frame depth estimation without any additional information.


```shell

input_dir="your/image/folder"
output_dir="your/output/folder"
model_checkpoint="your/model/checkpoint"

python inference.py \
--logtostderr \
--file_extension png \
--depth \
--egomotion true \
--input_dir $input_dir \
--output_dir $output_dir \
--model_ckpt $model_checkpoint
```

Note that the egomotion prediction expects the files in the input directory to be a consecutive sequence, and that sorting the filenames alphabetically is putting them in the right order.

One can also run inference on KITTI by providing

```shell
--input_list_file ~/kitti-raw-uncompressed/test_files_eigen.txt
```

and on Cityscapes by passing

```shell
--input_list_file CITYSCAPES_FULL/test_files_cityscapes.txt
```

instead of *input_dir*.
Alternatively inference can also be ran on pre-processed images.



## Running on-line refinement

On-line refinement is executed on top of an existing inference folder, so make sure to run regular inference first. Then you can run the on-line fusion procedure as follows:

```shell

prediction_dir="some/prediction/dir"
model_ckpt="checkpoints/checkpoints_baseline/model-199160"
handle_motion="false"
size_constraint_weight="0" # This must be zero when not handling motion.

# If running on KITTI, set as follows:
data_dir="KITTI_SEQ2_LR_EIGEN/"
triplet_list_file="$data_dir/test_files_eigen_triplets.txt"
triplet_list_file_remains="$data_dir/test_files_eigen_triplets_remains.txt"
ft_name="kitti"

# If running on Cityscapes, set as follows:
data_dir="CITYSCAPES_SEQ2_LR_TEST/" # Set for Cityscapes
triplet_list_file="/CITYSCAPES_SEQ2_LR_TEST/test_files_cityscapes_triplets.txt"
triplet_list_file_remains="CITYSCAPES_SEQ2_LR_TEST/test_files_cityscapes_triplets_remains.txt"
ft_name="cityscapes"

python optimize.py \
--logtostderr \
--output_dir $prediction_dir \
--data_dir $data_dir \
--triplet_list_file $triplet_list_file \
--triplet_list_file_remains $triplet_list_file_remains \
--ft_name $ft_name \
--model_ckpt $model_ckpt \
--file_extension png \
--handle_motion $handle_motion \
--size_constraint_weight $size_constraint_weight
```



## Running evaluation

```shell

prediction_dir="some/prediction/dir"

# Use these settings for KITTI:
eval_list_file="KITTI_FULL/kitti-raw-uncompressed/test_files_eigen.txt"
eval_crop="garg"
eval_mode="kitti"

# Use these settings for Cityscapes:
eval_list_file="CITYSCAPES_FULL/test_files_cityscapes.txt"
eval_crop="none"
eval_mode="cityscapes"

python evaluate.py \
--logtostderr \
--prediction_dir $prediction_dir \
--eval_list_file $eval_list_file \
--eval_crop $eval_crop \
--eval_mode $eval_mode
```



## Credits

This code is implemented and supported by Vincent Casser and Anelia Angelova and can be found at
https://sites.google.com/view/struct2depth.
The core implementation is derived from [https://github.com/tensorflow/models/tree/master/research/vid2depth)](https://github.com/tensorflow/models/tree/master/research/vid2depth)
by [Reza Mahjourian]([email protected]), which in turn is based on [SfMLearner
(https://github.com/tinghuiz/SfMLearner)](https://github.com/tinghuiz/SfMLearner)
by [Tinghui Zhou](https://github.com/tinghuiz).
54 changes: 54 additions & 0 deletions research/struct2depth/alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Common utilities for data pre-processing, e.g. matching moving object across frames."""

import numpy as np

def compute_overlap(mask1, mask2):
# Use IoU here.
return np.sum(mask1 & mask2)/np.sum(mask1 | mask2)

def align(seg_img1, seg_img2, seg_img3, threshold_same=0.3):
res_img1 = np.zeros_like(seg_img1)
res_img2 = np.zeros_like(seg_img2)
res_img3 = np.zeros_like(seg_img3)
remaining_objects2 = list(np.unique(seg_img2.flatten()))
remaining_objects3 = list(np.unique(seg_img3.flatten()))
for seg_id in np.unique(seg_img1):
# See if we can find correspondences to seg_id in seg_img2.
max_overlap2 = float('-inf')
max_segid2 = -1
for seg_id2 in remaining_objects2:
overlap = compute_overlap(seg_img1==seg_id, seg_img2==seg_id2)
if overlap>max_overlap2:
max_overlap2 = overlap
max_segid2 = seg_id2
if max_overlap2 > threshold_same:
max_overlap3 = float('-inf')
max_segid3 = -1
for seg_id3 in remaining_objects3:
overlap = compute_overlap(seg_img2==max_segid2, seg_img3==seg_id3)
if overlap>max_overlap3:
max_overlap3 = overlap
max_segid3 = seg_id3
if max_overlap3 > threshold_same:
res_img1[seg_img1==seg_id] = seg_id
res_img2[seg_img2==max_segid2] = seg_id
res_img3[seg_img3==max_segid3] = seg_id
remaining_objects2.remove(max_segid2)
remaining_objects3.remove(max_segid3)
return res_img1, res_img2, res_img3
158 changes: 158 additions & 0 deletions research/struct2depth/gen_data_city.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

""" Offline data generation for the Cityscapes dataset."""

import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import cv2
import os, glob

import alignment
from alignment import compute_overlap
from alignment import align


SKIP = 2
WIDTH = 416
HEIGHT = 128
SUB_FOLDER = 'train'
INPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_FULL/'
OUTPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_Processed/'

def crop(img, segimg, fx, fy, cx, cy):
# Perform center cropping, preserving 50% vertically.
middle_perc = 0.50
left = 1 - middle_perc
half = left / 2
a = img[int(img.shape[0]*(half)):int(img.shape[0]*(1-half)), :]
aseg = segimg[int(segimg.shape[0]*(half)):int(segimg.shape[0]*(1-half)), :]
cy /= (1 / middle_perc)

# Resize to match target height while preserving aspect ratio.
wdt = int((float(HEIGHT)*a.shape[1]/a.shape[0]))
x_scaling = float(wdt)/a.shape[1]
y_scaling = float(HEIGHT)/a.shape[0]
b = cv2.resize(a, (wdt, HEIGHT))
bseg = cv2.resize(aseg, (wdt, HEIGHT))

# Adjust intrinsics.
fx*=x_scaling
fy*=y_scaling
cx*=x_scaling
cy*=y_scaling

# Perform center cropping horizontally.
remain = b.shape[1] - WIDTH
cx /= (b.shape[1] / WIDTH)
c = b[:, int(remain/2):b.shape[1]-int(remain/2)]
cseg = bseg[:, int(remain/2):b.shape[1]-int(remain/2)]

return c, cseg, fx, fy, cx, cy


def run_all():
dir_name=INPUT_DIR + '/leftImg8bit_sequence/' + SUB_FOLDER + '/*'
print('Processing directory', dir_name)
for location in glob.glob(INPUT_DIR + '/leftImg8bit_sequence/' + SUB_FOLDER + '/*'):
location_name = os.path.basename(location)
print('Processing location', location_name)
files = sorted(glob.glob(location + '/*.png'))
files = [file for file in files if '-seg.png' not in file]
# Break down into sequences
sequences = {}
seq_nr = 0
last_seq = ''
last_imgnr = -1

for i in range(len(files)):
seq = os.path.basename(files[i]).split('_')[1]
nr = int(os.path.basename(files[i]).split('_')[2])
if seq!=last_seq or last_imgnr+1!=nr:
seq_nr+=1
last_imgnr = nr
last_seq = seq
if not seq_nr in sequences:
sequences[seq_nr] = []
sequences[seq_nr].append(files[i])

for (k,v) in sequences.items():
print('Processing sequence', k, 'with', len(v), 'elements...')
output_dir = OUTPUT_DIR + '/' + location_name + '_' + str(k)
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
files = sorted(v)
triplet = []
seg_triplet = []
ct = 1

# Find applicable intrinsics.
for j in range(len(files)):
osegname = os.path.basename(files[j]).split('_')[1]
oimgnr = os.path.basename(files[j]).split('_')[2]
applicable_intrinsics = INPUT_DIR + '/camera/' + SUB_FOLDER + '/' + location_name + '/' + location_name + '_' + osegname + '_' + oimgnr + '_camera.json'
# Get the intrinsics for one of the file of the sequence.
if os.path.isfile(applicable_intrinsics):
f = open(applicable_intrinsics, 'r')
lines = f.readlines()
f.close()
lines = [line.rstrip() for line in lines]

fx = float(lines[11].split(': ')[1].replace(',', ''))
fy = float(lines[12].split(': ')[1].replace(',', ''))
cx = float(lines[13].split(': ')[1].replace(',', ''))
cy = float(lines[14].split(': ')[1].replace(',', ''))

for j in range(0, len(files), SKIP):
img = cv2.imread(files[j])
segimg = cv2.imread(files[j].replace('.png', '-seg.png'))

smallimg, segimg, fx_this, fy_this, cx_this, cy_this = crop(img, segimg, fx, fy, cx, cy)
triplet.append(smallimg)
seg_triplet.append(segimg)
if len(triplet)==3:
cmb = np.hstack(triplet)
align1, align2, align3 = align(seg_triplet[0], seg_triplet[1], seg_triplet[2])
cmb_seg = np.hstack([align1, align2, align3])
cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '.png'), cmb)
cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '-fseg.png'), cmb_seg)
f = open(os.path.join(output_dir, str(ct).zfill(10) + '_cam.txt'), 'w')
f.write(str(fx_this) + ',0.0,' + str(cx_this) + ',0.0,' + str(fy_this) + ',' + str(cy_this) + ',0.0,0.0,1.0')
f.close()
del triplet[0]
del seg_triplet[0]
ct+=1

# Create file list for training. Be careful as it collects and includes all files recursively.
fn = open(OUTPUT_DIR + '/' + SUB_FOLDER + '.txt', 'w')
for f in glob.glob(OUTPUT_DIR + '/*/*.png'):
if '-seg.png' in f or '-fseg.png' in f:
continue
folder_name = f.split('/')[-2]
img_name = f.split('/')[-1].replace('.png', '')
fn.write(folder_name + ' ' + img_name + '\n')
fn.close()


def main(_):
run_all()


if __name__ == '__main__':
app.run(main)
Loading

0 comments on commit 811fa20

Please sign in to comment.