Extract clothes using 2d segmentation #72

Merged (12 commits, Jun 13, 2022)
30 changes: 28 additions & 2 deletions apps/infer.py
@@ -31,6 +31,7 @@
from lib.dataset.mesh_util import load_checkpoint, update_mesh_shape_prior_losses, get_optim_grid_image, blend_rgb_norm, unwrap
from lib.common.config import cfg
from lib.common.render import query_color
from lib.common.cloth_extraction import extract_cloth

import logging
logging.getLogger("trimesh").setLevel(logging.ERROR)
@@ -54,6 +55,8 @@ def tensor2variable(tensor, device):
parser.add_argument('-in_dir', '--in_dir', type=str, default="../examples")
parser.add_argument('-out_dir', '--out_dir', type=str, default="../results")
parser.add_argument('-cfg', '--config', type=str, default="../configs/icon-filter.yaml")
parser.add_argument('-seg_dir', '--seg_dir', type=str, default=None)
# parser.add_argument('-cloth_dir', '--cloth_dir', type=str, default=None)

args = parser.parse_args()

@@ -85,6 +88,7 @@ def tensor2variable(tensor, device):

dataset_param = {
'image_dir': args.in_dir,
'seg_dir': args.seg_dir,
'has_det': True, # w/ or w/o detection
'hps_type': args.hps_type # pymaf/pare/pixie
}
@@ -281,6 +285,7 @@ def tensor2variable(tensor, device):
os.makedirs(os.path.join(args.out_dir, cfg.name, "obj"),
exist_ok=True)


if cfg.net.prior_type != 'pifu':
per_data_lst[0].save(os.path.join(args.out_dir, cfg.name,
f"gif/{data['name']}_smpl.gif"),
@@ -464,6 +469,27 @@ def tensor2variable(tensor, device):
smpl_obj = trimesh.Trimesh(
in_tensor['smpl_verts'].detach().cpu()[0] *
torch.tensor([1.0, -1.0, 1.0]),
in_tensor['smpl_faces'].detach().cpu()[0],
process=False,
maintains_order=True
)
smpl_obj.export(
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.obj")
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.obj")

if args.seg_dir is not None:
os.makedirs(os.path.join(args.out_dir, cfg.name, "clothes"),
exist_ok=True)
for seg in data['segmentations']:
# These matrices work for PyMAF; the other HPS types have not been verified
K = np.array([[ 1.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 1.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, -0.5000, 0.0000],
[-0.0000, -0.0000, 0.5000, 1.0000]]).T

R = np.array([[-1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., -1.]])

t = np.array([[ -0., -0., 100.]])
clothing_obj = extract_cloth(recon_obj, seg, K, R, t, smpl_obj)
clothing_obj.export(os.path.join(args.out_dir, cfg.name, "clothes", f"{data['name']}_{seg['type'].replace(' ', '_')}.obj"))
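
With this change, per-image segmentation JSONs can be passed to the inference script through the new -seg_dir flag. A minimal invocation sketch, assuming the script is run from the apps/ directory and the segmentation files live in a sibling directory (both assumptions; only the flags come from this diff):

python infer.py -cfg ../configs/icon-filter.yaml -in_dir ../examples -out_dir ../results -seg_dir ../examples/segmentations

For every segmented garment, an OBJ file is then written to the clothes/ folder under the configured output directory, next to the existing obj/ output.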
157 changes: 157 additions & 0 deletions lib/common/cloth_extraction.py
@@ -0,0 +1,157 @@
import numpy as np
import json
import itertools
import trimesh
from matplotlib.path import Path
from collections import Counter
from sklearn.neighbors import KNeighborsClassifier

def load_segmentation(path, shape):
"""
Load the clothing segmentation polygons for a given image
Arguments:
path: path to the segmentation JSON file
shape: shape of the source image (currently unused)
Returns:
A list of dictionaries, one per clothing item, with the category name, category id and polygon coordinates
"""
with open(path) as json_file:
data = json.load(json_file)
segmentations = []
for key, val in data.items():
if not key.startswith('item'):
continue

# Each item can have multiple polygons; keep each polygon as a separate (N, 2) coordinate array

coordinates = []
for segmentation_coord in val['segmentation']:
# The source format is a flat list [x1, y1, x2, y2, ...]
x = segmentation_coord[::2]
y = segmentation_coord[1::2]
xy = np.vstack((x, y)).T
coordinates.append(xy)

segmentations.append({'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})

return segmentations
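
# Usage sketch (illustrative; the field values below are assumptions based on the keys read above):
# the loader expects a DeepFashion2-style JSON with one entry per clothing item, e.g.
# {
#   "item1": {"category_name": "short sleeve top", "category_id": 1,
#             "segmentation": [[x1, y1, x2, y2, ...], ...]},
#   "item2": {...}
# }
# segmentations = load_segmentation('image_01.json', (512, 512, 3))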

def smpl_to_recon_labels(recon, smpl, k=1):
"""
Get the body-part labels for the recon mesh by transferring the labels from the corresponding SMPL mesh
Arguments:
recon: trimesh object (fully clothed model)
smpl: trimesh object (SMPL model)
k: number of nearest neighbours to use
Returns:
A dictionary mapping each body part to the corresponding recon vertex indices
"""
smpl_vert_segmentation = json.load(open('../lib/common/smpl_vert_segmentation.json'))
n = smpl.vertices.shape[0]
y = np.array([None] * n)
for key, val in smpl_vert_segmentation.items():
y[val] = key

classifier = KNeighborsClassifier(n_neighbors=k)
classifier.fit(smpl.vertices, y)

y_pred = classifier.predict(recon.vertices)

recon_labels = {}
for key in smpl_vert_segmentation.keys():
recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int))

return recon_labels
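
# Usage sketch (paths are illustrative): both arguments are trimesh meshes, e.g.
# recon = trimesh.load('results/icon-filter/obj/demo_recon.obj', process=False)
# smpl = trimesh.load('results/icon-filter/obj/demo_smpl.obj', process=False)
# labels = smpl_to_recon_labels(recon, smpl)
# head_idx = labels['head']  # recon vertex indices classified as 'head'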

def extract_cloth(recon, segmentation, K, R, t, smpl = None):
"""
Extract a portion of a mesh using 2D segmentation coordinates
Arguments:
recon: fully clothed mesh
segmentation: segmentation dictionary with normalized 2D (NDC) polygon coordinates
K: intrinsic matrix of the projection
R: rotation matrix of the projection
t: translation vector of the projection
smpl: optional SMPL mesh used to remove body parts that cannot belong to the garment
Returns:
Returns a submesh covering the segmented region
"""
seg_coord = segmentation['coord_normalized']
mesh = trimesh.Trimesh(recon.vertices, recon.faces)
extrinsic = np.zeros((3,4))
extrinsic[:3, :3] = R
extrinsic[:,3] = t
P = K[:3, :3] @ extrinsic

P_inv = np.linalg.pinv(P)

# Each segmentation can contain multiple polygons; check each one separately
points_so_far = []
faces = recon.faces
for polygon in seg_coord:
n = len(polygon)
coords_h = np.hstack((polygon, np.ones((n,1))))
# Apply the inverse projection to the homogeneous 2D coordinates to get the corresponding 3D coordinates
XYZ = P_inv @ coords_h[:,:, None]
XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
XYZ = XYZ[:, :3] / XYZ[:,3, None]

p = Path(XYZ[:, :2])

inside = p.contains_points(recon.vertices[:, :2])
indices = np.argwhere(inside).flatten()
points_so_far += list(indices)

if smpl is not None:
num_verts = recon.vertices.shape[0]
recon_labels = smpl_to_recon_labels(recon, smpl)
body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head', 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand']
type_id = segmentation['type_id']

# Remove additional body parts that are most likely not part of the segmentation but might intersect it (e.g. a hand in front of the torso)
# https://github.com/switchablenorms/DeepFashion2
# Short sleeve clothes
if type_id in (1, 3, 10):
body_parts_to_remove += ['leftForeArm', 'rightForeArm']
# No sleeves at all or lower body clothes
elif type_id in (5, 6, 8, 9, 12, 13):
body_parts_to_remove += ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
# Shorts
elif type_id == 7:
body_parts_to_remove += ['leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']


verts_to_remove = list(itertools.chain.from_iterable([recon_labels[part] for part in body_parts_to_remove]))

label_mask = np.zeros(num_verts, dtype=bool)
label_mask[verts_to_remove] = True

seg_mask = np.zeros(num_verts, dtype=bool)
seg_mask[points_so_far] = True

# Remove points that belong to other body parts:
# if a vertex in points_so_far is also flagged by a body part to remove, it should be dropped
extra_verts_to_remove = np.logical_and(seg_mask, label_mask)

combine_mask = np.zeros(num_verts, dtype=bool)
combine_mask[points_so_far] = True
combine_mask[extra_verts_to_remove] = False

all_indices = np.argwhere(combine_mask).flatten()

i_x = np.where(np.in1d(faces[:,0], all_indices))[0]
i_y = np.where(np.in1d(faces[:,1], all_indices))[0]
i_z = np.where(np.in1d(faces[:,2], all_indices))[0]

faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
mask = np.zeros(len(recon.faces), dtype=bool)
mask[faces_to_keep] = True

mesh.update_faces(mask)
mesh.remove_unreferenced_vertices()

mesh.rezero()

return mesh
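
# Usage sketch (illustrative): the segmentation dict must carry normalized (NDC)
# polygon coordinates under 'coord_normalized' (that normalization happens outside
# this file), and K, R, t are the projection matrices set up in apps/infer.py:
# cloth = extract_cloth(recon_mesh, seg, K, R, t, smpl=smpl_mesh)
# cloth.export('short_sleeve_top.obj')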