-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit abb6c24
Showing
31 changed files
with
1,713 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Portrait Mating implementation in UNet with PyTorch. | ||
|
||
**Segmentation Demo Result:** | ||
![Segmentation](https://user-images.githubusercontent.com/30276789/76141416-03521900-609f-11ea-95e7-80d7ecf83760.png) | ||
**Matting Demo Result:** | ||
![Matting](https://user-images.githubusercontent.com/30276789/76142315-81b2b900-60a7-11ea-934d-35a00e50eda2.png) | ||
For the convenience of demonstration, I built the API service through Flask, and finally deployed it on WeChat Mini Program. | ||
The code part of the WeChat applet is in here [portrait-matting-wechat](https://github.com/leijue222/portrait-matting-wechat). | ||
|
||
## Dependencies | ||
|
||
- Python 3.6 | ||
- PyTorch >= 1.1.0 | ||
- Torchvision >= 0.3.0 | ||
- Flask 1.1.1 | ||
- future 0.18.2 | ||
- matplotlib 3.1.3 | ||
- numpy 1.16.0 | ||
- Pillow 6.2.0 | ||
- protobuf 3.11.3 | ||
- tensorboard 1.14.0 | ||
- tqdm==4.42.1 | ||
|
||
## Data | ||
This model was trained from scratch with 18000 images (data augmentation by 2000images) | ||
Training dataset was from [Deep Automatic Portrait Matting](http://www.cse.cuhk.edu.hk/leojia/projects/automatting/index.html). | ||
Your can download in baidu cloud [http://pan.baidu.com/s/1dE14537](http://pan.baidu.com/s/1dE14537). Password: ndg8 | ||
**For academic communication only, if there is a quote, please inform the original author!** | ||
|
||
We augment the number of images by perturbing them withrotation and scaling. Four rotation angles{−45◦,−22◦,22◦,45◦}and four scales{0.6,0.8,1.2,1.5}are used. We also apply four different Gamma transforms toincrease color variation. The Gamma values are{0.5,0.8,1.2,1.5}. After thesetransforms, we have 18K training images. | ||
|
||
## Run locally | ||
**Note : Use Python 3** | ||
### Prediction | ||
|
||
You can easily test the output masks on your images via the CLI. | ||
|
||
To predict a single image and save it: | ||
|
||
```bash | ||
$ python predict.py -i image.jpg -o output.jpg | ||
``` | ||
|
||
To predict a multiple images and show them without saving them: | ||
|
||
```bash | ||
$ python predict.py -i image1.jpg image2.jpg --viz --no-save | ||
``` | ||
|
||
```shell script | ||
> python predict.py -h | ||
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] | ||
[--output INPUT [INPUT ...]] [--viz] [--no-save] | ||
[--mask-threshold MASK_THRESHOLD] [--scale SCALE] | ||
|
||
Predict masks from input images | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
--model FILE, -m FILE | ||
Specify the file in which the model is stored | ||
(default: MODEL.pth) | ||
--input INPUT [INPUT ...], -i INPUT [INPUT ...] | ||
filenames of input images (default: None) | ||
--output INPUT [INPUT ...], -o INPUT [INPUT ...] | ||
Filenames of ouput images (default: None) | ||
--viz, -v Visualize the images as they are processed (default: | ||
False) | ||
--no-save, -n Do not save the output masks (default: False) | ||
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD | ||
Minimum probability value to consider a mask pixel | ||
white (default: 0.5) | ||
--scale SCALE, -s SCALE | ||
Scale factor for the input images (default: 0.5) | ||
``` | ||
You can specify which model file to use with `--model MODEL.pth`. | ||
### Training | ||
```shell script | ||
> python train.py -h | ||
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL] | ||
|
||
Train the UNet on images and target masks | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
-e E, --epochs E Number of epochs (default: 5) | ||
-b [B], --batch-size [B] | ||
Batch size (default: 1) | ||
-l [LR], --learning-rate [LR] | ||
Learning rate (default: 0.1) | ||
-f LOAD, --load LOAD Load model from a .pth file (default: False) | ||
-s SCALE, --scale SCALE | ||
Downscaling factor of the images (default: 0.5) | ||
-v VAL, --validation VAL | ||
Percent of the data that is used as validation (0-100) | ||
(default: 10.0) | ||
|
||
``` | ||
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. | ||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. | ||
### Start API service | ||
```bash | ||
$ python app.py | ||
``` | ||
Then you can use the model through the API | ||
## Run on server | ||
1. Install virtual environment | ||
2. Install gunicorn in a virtual environment | ||
3. Proxy through nginx | ||
## Notes on memory | ||
```bash | ||
$ python train.py -e 200 -b 1 -l 0.1 -s 0.5 -v 15.0 | ||
``` | ||
The model has be trained from scratch on a RTX2080Ti 11GB. | ||
18,000 training dataset, running for 4 days + | ||
## Thanks | ||
The birth of this project is inseparable from the following projects: | ||
- **[Flask](https://github.com/pallets/flask):The Python micro framework for building web applications** | ||
- **[Pytorch-UNet](https://github.com/milesial/Pytorch-UNet):PyTorch implementation of the U-Net for image semantic segmentation with high quality images** | ||
--- | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# -*- coding:utf-8 -*- | ||
import os | ||
import re | ||
import flask | ||
from PIL import Image | ||
|
||
from web.flask_config import input_path, output_path, bg_path | ||
from web.flask_utils import change_channels_to_rgb, merge_image_name, tid_maker | ||
from web.matting import process | ||
|
||
app = flask.Flask(__name__) | ||
model = None | ||
use_gpu = False | ||
|
||
|
||
def segmentation(image_path): | ||
names = re.findall(r'[^\\/:*?"<>|\r\n]+$', image_path) | ||
mask_path = output_path + names[0] | ||
change_channels_to_rgb(image_path) | ||
cmd_predict = "python predict.py -i {0} -o {1}".format(image_path, mask_path) | ||
os.system(cmd_predict) | ||
mask = Image.open(mask_path) | ||
img = Image.open(image_path) | ||
if mask.size != img.size: | ||
w, h = img.size | ||
mask = mask.resize((w, h)) | ||
mask.save(mask_path) | ||
return mask_path | ||
|
||
|
||
@app.route("/api/wechat/upload", methods=["POST"]) | ||
def upload(): | ||
data = {"success": False} | ||
file = flask.request.files['file'] | ||
file_type = flask.request.form.get("type", type=str) | ||
file_name = file_type + tid_maker() + '.png' | ||
file_path = './static/' + file_type + '/' + file_name | ||
file.save(file_path) | ||
file_photo = file_type + "/" + file_name + '.png' | ||
data["filePath"] = flask.url_for('static', _external=True, filename=file_photo) | ||
data["fileName"] = file_name | ||
data["success"] = True | ||
return flask.jsonify(data) | ||
|
||
|
||
@app.route("/api/wechat/matting", methods=["POST"]) | ||
def wechat_matting(): | ||
data = {"success": False} | ||
im_name = flask.request.form.get("im", type=str) | ||
bg_name = flask.request.form.get("bg", type=str) | ||
im_path = input_path + im_name | ||
segmentation(im_path) | ||
process(im_name, bg_name) | ||
|
||
merge_name = merge_image_name(im_name, bg_name) + '-merge.png' | ||
merge_photo = "merge/" + merge_name | ||
data["result"] = flask.url_for('static', _external=True, filename=merge_photo) | ||
data["success"] = True | ||
return flask.jsonify(data) | ||
|
||
|
||
@app.route("/api/seg", methods=["POST"]) | ||
def seg(): | ||
data = {"success": False} | ||
if flask.request.method == "POST": | ||
if flask.request.files.get("image"): | ||
input_image = flask.request.files['image'] | ||
file_name = input_image.filename | ||
image_path = input_path + file_name | ||
input_image.save(image_path) | ||
segmentation(image_path) | ||
|
||
in_photo = "input/" + file_name | ||
data["upload"] = flask.url_for('static', _external=True, filename=in_photo) | ||
out_photo = "output/" + file_name | ||
data["result"] = flask.url_for('static', _external=True, filename=out_photo) | ||
data["success"] = True | ||
return flask.jsonify(data) | ||
|
||
|
||
|
||
@app.route("/api/matting", methods=["POST"]) | ||
def matting(): | ||
data = {"success": False} | ||
print("files:", flask.request.files) | ||
bg_image = flask.request.files['bg'] | ||
im_image = flask.request.files['im'] | ||
bg_name = bg_image.filename | ||
im_name = im_image.filename | ||
im_path = input_path + im_name | ||
bg_image.save(bg_path + bg_name) | ||
im_image.save(im_path) | ||
|
||
segmentation(im_path) | ||
process(im_name, bg_name) | ||
|
||
merge_name = merge_image_name(im_name, bg_name) + '-merge.png' | ||
merge_photo = "merge/" + merge_name | ||
data["result"] = flask.url_for('static', _external=True, filename=merge_photo) | ||
data["success"] = True | ||
return flask.jsonify(data) | ||
|
||
|
||
@app.route("/api/clean", methods=["POST"]) | ||
def clean(): | ||
data = {"success": False} | ||
minutes = flask.request.form.get("minutes", type=str, default=30) | ||
if flask.request.method == "POST": | ||
cmd_in = "find ./static/input/ -type f -mmin +" + minutes + " -exec rm {} \;" | ||
cmd_out = "find ./static/output/ -type f -mmin +" + minutes + " -exec rm {} \;" | ||
cmd_bg = "find ./static/bg/ -type f -mmin +" + minutes + " -exec rm {} \;" | ||
cmd_merge = "find ./static/merge/ -type f -mmin +" + minutes + " -exec rm {} \;" | ||
os.system(cmd_in) | ||
os.system(cmd_out) | ||
os.system(cmd_bg) | ||
os.system(cmd_merge) | ||
data["success"] = True | ||
return flask.jsonify(data) | ||
|
||
|
||
if __name__ == '__main__': | ||
# im_name = '2.jpg' | ||
# bg_name = 'bg3.jpg' | ||
# img, alpha, fg, bg = process(im_name, bg_name) | ||
app.run(host='127.0.0.1', port=5000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
put images in this directory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
put targer masks in this directory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import torch | ||
from torch.autograd import Function | ||
|
||
|
||
class DiceCoeff(Function): | ||
"""Dice coeff for individual examples""" | ||
|
||
def forward(self, input, target): | ||
self.save_for_backward(input, target) | ||
eps = 0.0001 | ||
self.inter = torch.dot(input.view(-1), target.view(-1)) | ||
self.union = torch.sum(input) + torch.sum(target) + eps | ||
|
||
t = (2 * self.inter.float() + eps) / self.union.float() | ||
return t | ||
|
||
# This function has only a single output, so it gets only one gradient | ||
def backward(self, grad_output): | ||
|
||
input, target = self.saved_variables | ||
grad_input = grad_target = None | ||
|
||
if self.needs_input_grad[0]: | ||
grad_input = grad_output * 2 * (target * self.union - self.inter) \ | ||
/ (self.union * self.union) | ||
if self.needs_input_grad[1]: | ||
grad_target = None | ||
|
||
return grad_input, grad_target | ||
|
||
|
||
def dice_coeff(input, target): | ||
"""Dice coeff for batches""" | ||
if input.is_cuda: | ||
s = torch.FloatTensor(1).cuda().zero_() | ||
else: | ||
s = torch.FloatTensor(1).zero_() | ||
|
||
for i, c in enumerate(zip(input, target)): | ||
s = s + DiceCoeff().forward(c[0], c[1]) | ||
|
||
return s / (i + 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from tqdm import tqdm | ||
|
||
from dice_loss import dice_coeff | ||
|
||
|
||
def eval_net(net, loader, device, n_val): | ||
"""Evaluation without the densecrf with the dice coefficient""" | ||
net.eval() | ||
tot = 0 | ||
|
||
with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar: | ||
for batch in loader: | ||
imgs = batch['image'] | ||
true_masks = batch['mask'] | ||
|
||
imgs = imgs.to(device=device, dtype=torch.float32) | ||
mask_type = torch.float32 if net.n_classes == 1 else torch.long | ||
true_masks = true_masks.to(device=device, dtype=mask_type) | ||
|
||
mask_pred = net(imgs) | ||
|
||
for true_mask, pred in zip(true_masks, mask_pred): | ||
pred = (pred > 0.5).float() | ||
if net.n_classes > 1: | ||
tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() | ||
else: | ||
tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item() | ||
pbar.update(imgs.shape[0]) | ||
|
||
return tot / n_val |
Oops, something went wrong.