
Commit

release
WXinlong committed Mar 26, 2023
1 parent 6cea8ec commit e7a6e0d
Showing 115 changed files with 15,145 additions and 6 deletions.
16 changes: 16 additions & 0 deletions .gitignore
@@ -0,0 +1,16 @@
__pycache__
datasets/
toy_datasets/
models
models_inference
work_dirs/
wandb
datasets
.idea

.nfs*
*.pth
log.txt
log*.txt
demo/ood
*log.txt
59 changes: 53 additions & 6 deletions README.md
@@ -5,23 +5,21 @@

<sup>1</sup>[BAAI](https://www.baai.ac.cn/english.html), &nbsp; <sup>2</sup>[ZJU](https://www.zju.edu.cn/english/), &nbsp; <sup>3</sup>[PKU](https://english.pku.edu.cn/)


CVPR 2023


<br>

<image src="teaser.jpg" width="720px" />
<image src="docs/teaser.jpg" width="720px" />
<br>

</div>

<br>

We present Painter, a generalist model using an "image"-centric solution for in-context visual learning: the outputs of core vision tasks are redefined as images, and task prompts are specified as images as well. With this idea, our training process is extremely simple: we perform standard masked image modeling on the stitch of input and output image pairs. This makes the model capable of performing tasks conditioned on visible image patches. Thus, during inference, we can adopt a pair of input and output images from the same task as the input condition to indicate which task to perform. Examples of in-context inference are illustrated in the figure above, consisting of seven in-domain examples (top seven rows) and three out-of-domain examples (bottom three rows).
Without bells and whistles, our generalist Painter achieves competitive performance compared to well-established task-specific models on seven representative vision tasks, ranging from high-level visual understanding to low-level image processing.
Surprisingly, our model shows the capability to complete out-of-domain tasks that do not exist in the training data, such as open-category keypoint detection and object segmentation, validating the powerful task transferability of in-context learning.
In addition, Painter significantly outperforms recent generalist models on several challenging tasks.

[[Paper]](https://arxiv.org/abs/2212.02499)

@@ -42,6 +40,52 @@ Surprisingly, our model shows capabilities of completing out-of-domain tasks, wh
- even when the tasks do not exist in the training data


## Installation
See [installation instructions](docs/INSTALL.md).

## Data
See [data instructions](docs/DATA.md).

We also provide [a toy training dataset](https://huggingface.co/BAAI/Painter/blob/main/toy_datasets.tar), with 10 samples from each required dataset. You can put it in `$Painter_ROOT/toy_datasets` and set `DATA_PATH=toy_datasets` in `$Painter_ROOT/train_painter_vit_large.sh` for toy experiments.
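
A rough sketch of the toy setup (the `resolve/` download URL is an assumption inferred from the link above; adjust it if your download path differs):

```bash
# Sketch only: fetch the toy dataset and point the training script at it.
cd $Painter_ROOT
wget https://huggingface.co/BAAI/Painter/resolve/main/toy_datasets.tar
tar -xf toy_datasets.tar            # should produce toy_datasets/
# then edit train_painter_vit_large.sh and set: DATA_PATH=toy_datasets
```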

## Training
Download the pre-trained MAE ViT-Large model from [here](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth) and replace `path/to/mae_pretrain_vit_large.pth` in `$Painter_ROOT/train_painter_vit_large.sh` with its actual location.
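
For example (a sketch only; the placeholder string comes from the sentence above, and your download location may differ):

```bash
# Sketch only: fetch the MAE ViT-Large weights used to initialize Painter.
cd $Painter_ROOT
wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth
# then replace path/to/mae_pretrain_vit_large.pth in train_painter_vit_large.sh
# with the absolute path of the downloaded file, e.g. $(pwd)/mae_pretrain_vit_large.pth
```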

We use 8 nodes (<code>total_bsz = 8x8x32 = 2048</code>) for training:



```bash
bash train_painter_vit_large.sh
```

## Evaluation
See [evaluation instructions](docs/EVAL.md).

A pre-trained Painter is available at [🤗 Hugging Face Models](https://huggingface.co/BAAI/Painter/blob/main/painter_vit_large.pth). The results on various tasks are summarized below:


<table border="1" width="100%">
<tr align="center">
<!-- <th> Task </th> -->
<th colspan="3"> depth estimation </th><th colspan="1"> semantic seg. </th><th colspan="1">panoptic seg.</th><th colspan="1">keypoint det.</th> <th colspan="2"> denoising </th> <th colspan="2"> deraining </th> <th colspan="2"> enhance.</th>
</tr>
<tr align="center">
<!-- <th> Dataset </th> -->
<th colspan="3"> NYU v2 </th><th colspan="1"> ADE20k </th><th colspan="1"> COCO 2017 </th><th colspan="1"> COCO 2017 </th> <th colspan="2"> SIDD </th> <th colspan="2"> 5 datasets </th> <th colspan="2"> LoL </th>
</tr>
<tr align="center">
<!-- <th> Metric </th> -->
<th> RMSE </th> <th> A.Rel </th> <th> d1 </th> <th colspan="1"> mIoU </th><th colspan="1">PQ</th><th colspan="1">AP</th> <th> PSNR </th> <th> SSIM </th> <th> PSNR </th> <th> SSIM </th> <th> PSNR </th> <th> SSIM </th>
</tr>
<tr align="center">
<!-- <th> Painter </th> -->
<th> 0.288 </th> <th> 0.080 </th> <th> 0.950 </th> <th colspan="1"> 49.9 </th> <th> 43.4 </th> <th>72.1</th> <th> 38.66 </th> <th> 0.954 </th> <th> 29.42 </th> <th> 0.867 </th> <th> 22.34 </th> <th> 0.872 </th>
</tr>
</table>
<br>
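
To evaluate against these numbers, a minimal sketch for fetching the released checkpoint (the `resolve/` download path and the `models_inference/` location, which mirrors a directory listed in `.gitignore`, are assumptions; see [EVAL.md](docs/EVAL.md) for the exact evaluation commands):

```bash
# Sketch only: download the released Painter ViT-Large checkpoint.
cd $Painter_ROOT
mkdir -p models_inference
wget -P models_inference https://huggingface.co/BAAI/Painter/resolve/main/painter_vit_large.pth
```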


## Citation

```
@@ -53,6 +97,9 @@ Surprisingly, our model shows capabilities of completing out-of-domain tasks, wh
}
```

## Acknowledgement
[MAE](https://github.com/facebookresearch/mae), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [detectron2](https://github.com/facebookresearch/detectron2), [Mask2Former](https://github.com/facebookresearch/Mask2Former), [bts](https://github.com/cleinc/bts), [mmcv](https://github.com/open-mmlab/mmcv), [mmdetection](https://github.com/open-mmlab/mmdetection), [mmpose](https://github.com/open-mmlab/mmpose), [MIRNet](https://github.com/swz30/MIRNet), [MPRNet](https://github.com/swz30/MPRNet), and [Uformer](https://github.com/ZhendongWang6/Uformer).

## Contact

**We are hiring** at all levels at the BAAI Vision Team, including full-time researchers, engineers, and interns.
Empty file added data/__init__.py
Empty file.
145 changes: 145 additions & 0 deletions data/ade20k/gen_color_ade20k_sem.py
@@ -0,0 +1,145 @@
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------'

import os
import glob
import argparse
import json
import tqdm
import sys
sys.path.insert(0, "data")

import numpy as np
from PIL import Image


def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    "copied from https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/mit_semseg/utils.py"
    ar = np.asanyarray(ar).flatten()

    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            ret = ar
        else:
            ret = (ar,)
            if return_index:
                ret += (np.empty(0, bool),)  # np.bool is removed in recent NumPy; use the builtin bool
            if return_inverse:
                ret += (np.empty(0, bool),)
            if return_counts:
                ret += (np.empty(0, np.intp),)
        return ret
    if optional_indices:
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        ret = aux[flag]
    else:
        ret = (aux[flag],)
        if return_index:
            ret += (perm[flag],)
        if return_inverse:
            iflag = np.cumsum(flag) - 1
            inv_idx = np.empty(ar.shape, dtype=np.intp)
            inv_idx[perm] = iflag
            ret += (inv_idx,)
        if return_counts:
            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
            ret += (np.diff(idx),)
    return ret


def colorEncode(labelmap, colors, mode='RGB'):
    "Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/mit_semseg/utils.py"
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)

    # paint each class present in the map with its palette color; label 0 (background) stays black
    for label in unique(labelmap):
        if label <= 0:
            continue
        # note the color_index = class_index - 1
        labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
            np.tile(np.array(colors[label-1], dtype=np.uint8), (labelmap.shape[0], labelmap.shape[1], 1))

    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    else:
        return labelmap_rgb


def define_colors_per_location_mean_sep():
    # encode each of the 150 class indices as a distinct RGB color, one base-6 "digit" per channel
    num_locations = 150
    num_sep_per_channel = int(num_locations ** (1 / 3)) + 1  # = 6 separations per channel for 150 classes
    separation_per_channel = 256 // num_sep_per_channel

    color_list = []
    for location in range(num_locations):
        num_seq_r = location // num_sep_per_channel ** 2
        num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel
        num_seq_b = location % num_sep_per_channel
        assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \
               and (num_seq_b <= num_sep_per_channel)

        R = 255 - num_seq_r * separation_per_channel
        G = 255 - num_seq_g * separation_per_channel
        B = 255 - num_seq_b * separation_per_channel
        assert (R < 256) and (G < 256) and (B < 256)
        assert (R >= 0) and (G >= 0) and (B >= 0)
        assert (R, G, B) not in color_list

        color_list.append((R, G, B))
        # print(location, (num_seq_r, num_seq_g, num_seq_b), (R, G, B))

    return color_list


PALETTE = define_colors_per_location_mean_sep()


def get_args_parser():
    parser = argparse.ArgumentParser('ADE20k semantic segmentation preparation', add_help=False)
    parser.add_argument('--split', type=str, help='dataset split',
                        choices=['training', 'validation'], required=True)
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args_parser()

    image_dir = os.path.join("datasets/ade20k/images", args.split)
    segm_dir = os.path.join("datasets/ade20k/annotations", args.split)
    save_dir = os.path.join("datasets/ade20k/annotations_with_color", args.split)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    color_list = define_colors_per_location_mean_sep()

    segm_path_list = glob.glob(os.path.join(segm_dir, '*.png'))
    for segm_path in tqdm.tqdm(segm_path_list):
        # check files
        file_name = os.path.basename(segm_path)
        # in ade20k, images are jpegs, while segms are pngs
        image_path = os.path.join(image_dir, file_name.replace('.png', '.jpg'))
        assert os.path.isfile(segm_path)
        assert os.path.isfile(image_path)

        # paint colors on segm
        segm = Image.open(segm_path)
        segm_color = colorEncode(labelmap=np.array(segm), colors=color_list).astype(np.uint8)
        segm_color = Image.fromarray(segm_color)
        segm_color.save(os.path.join(save_dir, file_name))
47 changes: 47 additions & 0 deletions data/ade20k/gen_json_ade20k_sem.py
@@ -0,0 +1,47 @@
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------'

import os
import glob
import json
import tqdm
import argparse


def get_args_parser():
    parser = argparse.ArgumentParser('ADE20k semantic segmentation preparation', add_help=False)
    parser.add_argument('--split', type=str, help='dataset split',
                        choices=['training', 'validation'], required=True)
    parser.add_argument('--output_dir', type=str, help='path to output dir',
                        default='datasets/ade20k')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args_parser()

    image_dir = os.path.join("datasets/ade20k/images", args.split)
    annos_dir = os.path.join("datasets/ade20k/annotations_with_color", args.split)
    save_path = os.path.join(args.output_dir, "ade20k_{}_image_semantic.json".format(args.split))

    output_dict = []

    # '*g' matches both .jpg and .png file names
    image_path_list = glob.glob(os.path.join(image_dir, '*g'))
    for image_path in tqdm.tqdm(image_path_list):
        image_name = image_path.split('/')[-1].split('.')[0]
        image_path = os.path.join(image_dir, image_name + '.jpg')
        panoptic_path = os.path.join(annos_dir, image_name + '.png')
        assert os.path.isfile(image_path)
        assert os.path.isfile(panoptic_path)
        pair_dict = {}
        pair_dict["image_path"] = os.path.join("ade20k/images/{}/".format(args.split), image_name + ".jpg")
        pair_dict["target_path"] = "ade20k/annotations_with_color/{}/".format(args.split) + image_name + ".png"
        pair_dict["type"] = "ade20k_image2semantic"
        output_dict.append(pair_dict)
    json.dump(output_dict, open(save_path, 'w'))
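
A usage sketch for the two ADE20k preparation scripts above (assuming `datasets/ade20k` is laid out as described in `docs/DATA.md`; run from the repository root, since the scripts hard-code relative paths):

```bash
# Sketch only: build the color-coded ADE20k annotations, then the image/target pairing JSONs.
cd $Painter_ROOT
for split in training validation; do
    python data/ade20k/gen_color_ade20k_sem.py --split $split
    python data/ade20k/gen_json_ade20k_sem.py --split $split
done
```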