add colored val preds
brdav committed Jul 19, 2022
1 parent 16a623c commit 824f086
Showing 3 changed files with 30 additions and 3 deletions.
README.md (11 changes: 9 additions & 2 deletions)
@@ -93,7 +93,7 @@ Before running the code, download and extract the corresponding datasets to the
 </details>
 
 <details>
-<summary>DarkZurich</summary>
+<summary>Dark Zurich</summary>
 
 Download Dark_Zurich_train_anon.zip, Dark_Zurich_val_anon.zip, and Dark_Zurich_test_anon_withoutGt.zip from [here](https://www.trace.ethz.ch/publications/2019/GCMA_UIoU/) and extract them to `$DATA_DIR/DarkZurich`.
 
@@ -113,7 +113,7 @@ Before running the code, download and extract the corresponding datasets to the
 </details>
 
 <details>
-<summary>NighttimeDriving</summary>
+<summary>Nighttime Driving</summary>
 
 Download NighttimeDrivingTest.zip from [here](http://people.ee.ethz.ch/~daid/NightDriving/) and extract it to `$DATA_DIR/NighttimeDrivingTest`.
 
@@ -230,6 +230,13 @@ We provide pretrained models of both UDA and alignment networks.
 
 Note that the UAWarpC checkpoint is needed to train Refign. To avoid config file edits, save it to `./pretrained_models/`.
 
+### Qualitative Refign Predictions
+
+To facilitate qualitative comparisons, validation set predictions of Refign can be directly downloaded:
+- [Refign on ACDC val](https://data.vision.ee.ethz.ch/brdavid/refign/colored_preds_val_ACDC.zip)
+- [Refign on Dark Zurich val](https://data.vision.ee.ethz.ch/brdavid/refign/colored_preds_val_DarkZurich.zip)
+- [Refign on RobotCar val](https://data.vision.ee.ethz.ch/brdavid/refign/colored_preds_val_RobotCar.zip)
+
 ### Refign Training
 
 Make sure to first download the trained UAWarpC model with the link provided above.
helpers/utils.py (15 changes: 15 additions & 0 deletions)
@@ -1,6 +1,14 @@
 import os
 
 import pytorch_lightning as pl
+from PIL import Image
+
+palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
+           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
+           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
+zero_pad = 256 * 3 - len(palette)
+for i in range(zero_pad):
+    palette.append(0)
 
 
 def resolve_ckpt_dir(trainer: pl.Trainer):
@@ -25,3 +33,10 @@ def resolve_ckpt_dir(trainer: pl.Trainer):
     ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
 
     return ckpt_path
+
+
+def colorize_mask(mask):
+    assert isinstance(mask, Image.Image)
+    new_mask = mask.convert('P')
+    new_mask.putpalette(palette)
+    return new_mask
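The `palette` above holds RGB triplets for the 19 Cityscapes training classes (19 × 3 = 57 values) and is zero-padded to the 256 × 3 entries a PIL `'P'`-mode palette expects, so any label outside 0-18 (for example the ignore index 255) is rendered black. A minimal usage sketch, assuming a grayscale prediction mask stored as a PNG of train IDs (file names are placeholders, not part of the repository):

```python
# Minimal sketch: colorize one saved prediction mask with the Cityscapes palette.
# 'pred_mask.png' is a placeholder for an 'L'-mode PNG whose pixels are train IDs.
from PIL import Image

from helpers.utils import colorize_mask

pred = Image.open('pred_mask.png')   # grayscale mask of class indices 0-18 (255 = ignore)
colored = colorize_mask(pred)        # 'P'-mode copy with the Cityscapes color map applied
colored.save('pred_mask_color.png')
```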
models/segmentation_model.py (7 changes: 6 additions & 1 deletion)
@@ -12,7 +12,7 @@
 from helpers.matching_utils import (
     estimate_probability_of_confidence_interval_of_mixture_density, warp)
 from helpers.metrics import MyMetricCollection
-from helpers.utils import resolve_ckpt_dir
+from helpers.utils import colorize_mask, resolve_ckpt_dir
 from PIL import Image
 from pytorch_lightning.utilities.cli import MODEL_REGISTRY, instantiate_class
 
@@ -215,8 +215,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         dataset_name = self.trainer.datamodule.predict_on[dataloader_idx]
         save_dir = os.path.join(os.path.dirname(
             resolve_ckpt_dir(self.trainer)), 'preds', dataset_name)
+        col_save_dir = os.path.join(os.path.dirname(
+            resolve_ckpt_dir(self.trainer)), 'color_preds', dataset_name)
         if self.trainer.is_global_zero:
             os.makedirs(save_dir, exist_ok=True)
+            os.makedirs(col_save_dir, exist_ok=True)
         img_names = batch['filename']
         x = batch['image']
         orig_size = self.trainer.datamodule.predict_ds[dataloader_idx].orig_dims
@@ -226,6 +229,8 @@
             arr = pred.cpu().numpy()
             image = Image.fromarray(arr.astype(np.uint8))
             image.save(os.path.join(save_dir, im_name))
+            col_image = colorize_mask(image)
+            col_image.save(os.path.join(col_save_dir, im_name))
 
     def forward(self, x, out_size=None, return_feats=False):
         feats = self.backbone(x)
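Together with the helper above, `predict_step` now writes, next to each raw train-ID mask under `preds/<dataset_name>`, a color-coded copy under `color_preds/<dataset_name>`. A rough standalone sketch of the same post-processing step, assuming a directory of already-exported predictions (the paths below are placeholders, not the actual checkpoint layout):

```python
# Standalone sketch: build color-coded copies of an exported predictions folder,
# mirroring the new lines in predict_step. Paths below are placeholders.
import os

from PIL import Image

from helpers.utils import colorize_mask

save_dir = 'preds/ACDC'            # raw train-ID masks written by predict_step
col_save_dir = 'color_preds/ACDC'  # destination for the color-coded copies
os.makedirs(col_save_dir, exist_ok=True)

for im_name in os.listdir(save_dir):
    image = Image.open(os.path.join(save_dir, im_name))
    col_image = colorize_mask(image)
    col_image.save(os.path.join(col_save_dir, im_name))
```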
