Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
819598 committed Jul 10, 2020
0 parents commit abb6c24
Show file tree
Hide file tree
Showing 31 changed files with 1,713 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions .idea/portrait-matting-unet-flask.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

674 changes: 674 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

Binary file added MODEL.pth
Binary file not shown.
134 changes: 134 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Portrait Mating implementation in UNet with PyTorch.

**Segmentation Demo Result:**
![Segmentation](https://user-images.githubusercontent.com/30276789/76141416-03521900-609f-11ea-95e7-80d7ecf83760.png)
**Matting Demo Result:**
![Matting](https://user-images.githubusercontent.com/30276789/76142315-81b2b900-60a7-11ea-934d-35a00e50eda2.png)
For the convenience of demonstration, I built the API service through Flask, and finally deployed it on WeChat Mini Program.
The code part of the WeChat applet is in here [portrait-matting-wechat](https://github.com/leijue222/portrait-matting-wechat).

## Dependencies

- Python 3.6
- PyTorch >= 1.1.0
- Torchvision >= 0.3.0
- Flask 1.1.1
- future 0.18.2
- matplotlib 3.1.3
- numpy 1.16.0
- Pillow 6.2.0
- protobuf 3.11.3
- tensorboard 1.14.0
- tqdm==4.42.1

## Data
This model was trained from scratch with 18000 images (data augmentation by 2000images)
Training dataset was from [Deep Automatic Portrait Matting](http://www.cse.cuhk.edu.hk/leojia/projects/automatting/index.html).
Your can download in baidu cloud [http://pan.baidu.com/s/1dE14537](http://pan.baidu.com/s/1dE14537). Password: ndg8
**For academic communication only, if there is a quote, please inform the original author!**

We augment the number of images by perturbing them withrotation and scaling. Four rotation angles{−45◦,−22◦,22◦,45◦}and four scales{0.6,0.8,1.2,1.5}are used. We also apply four different Gamma transforms toincrease color variation. The Gamma values are{0.5,0.8,1.2,1.5}. After thesetransforms, we have 18K training images.

## Run locally
**Note : Use Python 3**
### Prediction

You can easily test the output masks on your images via the CLI.

To predict a single image and save it:

```bash
$ python predict.py -i image.jpg -o output.jpg
```

To predict a multiple images and show them without saving them:

```bash
$ python predict.py -i image1.jpg image2.jpg --viz --no-save
```

```shell script
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
[--output INPUT [INPUT ...]] [--viz] [--no-save]
[--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
-h, --help show this help message and exit
--model FILE, -m FILE
Specify the file in which the model is stored
(default: MODEL.pth)
--input INPUT [INPUT ...], -i INPUT [INPUT ...]
filenames of input images (default: None)
--output INPUT [INPUT ...], -o INPUT [INPUT ...]
Filenames of ouput images (default: None)
--viz, -v Visualize the images as they are processed (default:
False)
--no-save, -n Do not save the output masks (default: False)
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
Minimum probability value to consider a mask pixel
white (default: 0.5)
--scale SCALE, -s SCALE
Scale factor for the input images (default: 0.5)
```
You can specify which model file to use with `--model MODEL.pth`.
### Training
```shell script
> python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]

Train the UNet on images and target masks

optional arguments:
-h, --help show this help message and exit
-e E, --epochs E Number of epochs (default: 5)
-b [B], --batch-size [B]
Batch size (default: 1)
-l [LR], --learning-rate [LR]
Learning rate (default: 0.1)
-f LOAD, --load LOAD Load model from a .pth file (default: False)
-s SCALE, --scale SCALE
Downscaling factor of the images (default: 0.5)
-v VAL, --validation VAL
Percent of the data that is used as validation (0-100)
(default: 10.0)

```
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.
### Start API service
```bash
$ python app.py
```
Then you can use the model through the API
## Run on server
1. Install virtual environment
2. Install gunicorn in a virtual environment
3. Proxy through nginx
## Notes on memory
```bash
$ python train.py -e 200 -b 1 -l 0.1 -s 0.5 -v 15.0
```
The model has be trained from scratch on a RTX2080Ti 11GB.
18,000 training dataset, running for 4 days +
## Thanks
The birth of this project is inseparable from the following projects:
- **[Flask](https://github.com/pallets/flask):The Python micro framework for building web applications**
- **[Pytorch-UNet](https://github.com/milesial/Pytorch-UNet):PyTorch implementation of the U-Net for image semantic segmentation with high quality images**
---
125 changes: 125 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding:utf-8 -*-
import os
import re
import flask
from PIL import Image

from web.flask_config import input_path, output_path, bg_path
from web.flask_utils import change_channels_to_rgb, merge_image_name, tid_maker
from web.matting import process

app = flask.Flask(__name__)
model = None
use_gpu = False


def segmentation(image_path):
names = re.findall(r'[^\\/:*?"<>|\r\n]+$', image_path)
mask_path = output_path + names[0]
change_channels_to_rgb(image_path)
cmd_predict = "python predict.py -i {0} -o {1}".format(image_path, mask_path)
os.system(cmd_predict)
mask = Image.open(mask_path)
img = Image.open(image_path)
if mask.size != img.size:
w, h = img.size
mask = mask.resize((w, h))
mask.save(mask_path)
return mask_path


@app.route("/api/wechat/upload", methods=["POST"])
def upload():
data = {"success": False}
file = flask.request.files['file']
file_type = flask.request.form.get("type", type=str)
file_name = file_type + tid_maker() + '.png'
file_path = './static/' + file_type + '/' + file_name
file.save(file_path)
file_photo = file_type + "/" + file_name + '.png'
data["filePath"] = flask.url_for('static', _external=True, filename=file_photo)
data["fileName"] = file_name
data["success"] = True
return flask.jsonify(data)


@app.route("/api/wechat/matting", methods=["POST"])
def wechat_matting():
data = {"success": False}
im_name = flask.request.form.get("im", type=str)
bg_name = flask.request.form.get("bg", type=str)
im_path = input_path + im_name
segmentation(im_path)
process(im_name, bg_name)

merge_name = merge_image_name(im_name, bg_name) + '-merge.png'
merge_photo = "merge/" + merge_name
data["result"] = flask.url_for('static', _external=True, filename=merge_photo)
data["success"] = True
return flask.jsonify(data)


@app.route("/api/seg", methods=["POST"])
def seg():
data = {"success": False}
if flask.request.method == "POST":
if flask.request.files.get("image"):
input_image = flask.request.files['image']
file_name = input_image.filename
image_path = input_path + file_name
input_image.save(image_path)
segmentation(image_path)

in_photo = "input/" + file_name
data["upload"] = flask.url_for('static', _external=True, filename=in_photo)
out_photo = "output/" + file_name
data["result"] = flask.url_for('static', _external=True, filename=out_photo)
data["success"] = True
return flask.jsonify(data)



@app.route("/api/matting", methods=["POST"])
def matting():
data = {"success": False}
print("files:", flask.request.files)
bg_image = flask.request.files['bg']
im_image = flask.request.files['im']
bg_name = bg_image.filename
im_name = im_image.filename
im_path = input_path + im_name
bg_image.save(bg_path + bg_name)
im_image.save(im_path)

segmentation(im_path)
process(im_name, bg_name)

merge_name = merge_image_name(im_name, bg_name) + '-merge.png'
merge_photo = "merge/" + merge_name
data["result"] = flask.url_for('static', _external=True, filename=merge_photo)
data["success"] = True
return flask.jsonify(data)


@app.route("/api/clean", methods=["POST"])
def clean():
data = {"success": False}
minutes = flask.request.form.get("minutes", type=str, default=30)
if flask.request.method == "POST":
cmd_in = "find ./static/input/ -type f -mmin +" + minutes + " -exec rm {} \;"
cmd_out = "find ./static/output/ -type f -mmin +" + minutes + " -exec rm {} \;"
cmd_bg = "find ./static/bg/ -type f -mmin +" + minutes + " -exec rm {} \;"
cmd_merge = "find ./static/merge/ -type f -mmin +" + minutes + " -exec rm {} \;"
os.system(cmd_in)
os.system(cmd_out)
os.system(cmd_bg)
os.system(cmd_merge)
data["success"] = True
return flask.jsonify(data)


if __name__ == '__main__':
# im_name = '2.jpg'
# bg_name = 'bg3.jpg'
# img, alpha, fg, bg = process(im_name, bg_name)
app.run(host='127.0.0.1', port=5000)
1 change: 1 addition & 0 deletions data/imgs/imgs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
put images in this directory
1 change: 1 addition & 0 deletions data/masks/mask.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
put targer masks in this directory
42 changes: 42 additions & 0 deletions dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from torch.autograd import Function


class DiceCoeff(Function):
"""Dice coeff for individual examples"""

def forward(self, input, target):
self.save_for_backward(input, target)
eps = 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps

t = (2 * self.inter.float() + eps) / self.union.float()
return t

# This function has only a single output, so it gets only one gradient
def backward(self, grad_output):

input, target = self.saved_variables
grad_input = grad_target = None

if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * self.union - self.inter) \
/ (self.union * self.union)
if self.needs_input_grad[1]:
grad_target = None

return grad_input, grad_target


def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = torch.FloatTensor(1).cuda().zero_()
else:
s = torch.FloatTensor(1).zero_()

for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])

return s / (i + 1)
32 changes: 32 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dice_loss import dice_coeff


def eval_net(net, loader, device, n_val):
"""Evaluation without the densecrf with the dice coefficient"""
net.eval()
tot = 0

with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
for batch in loader:
imgs = batch['image']
true_masks = batch['mask']

imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)

mask_pred = net(imgs)

for true_mask, pred in zip(true_masks, mask_pred):
pred = (pred > 0.5).float()
if net.n_classes > 1:
tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
else:
tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
pbar.update(imgs.shape[0])

return tot / n_val
Loading

0 comments on commit abb6c24

Please sign in to comment.