-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
46 changed files
with
6,455 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from optimization.image_editor import ImageEditor | ||
from optimization.arguments import get_arguments | ||
|
||
if __name__ == "__main__": | ||
args = get_arguments() | ||
image_editor = ImageEditor(args) | ||
image_editor.edit_image_by_prompt() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .models import ArcMarginModel | ||
from .models import ResNet | ||
from .models import IRBlock | ||
from .models import SEBlock |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import torch.nn as nn | ||
from models.gaze_estimation.models.eyenet import EyeNet | ||
from torchvision import transforms | ||
|
||
class Gaze_estimator(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.device = torch.device('cpu') | ||
self.checkpoint = torch.load('checkpoints/GazeEstimator.pt', map_location=self.device) | ||
self.nstack = self.checkpoint['nstack'] | ||
self.nfeatures = self.checkpoint['nfeatures'] | ||
self.nlandmarks = self.checkpoint['nlandmarks'] | ||
self.eyenet = EyeNet(nstack=self.nstack, nfeatures=self.nfeatures, nlandmarks=self.nlandmarks).to(self.device) | ||
self.eyenet.load_state_dict(self.checkpoint['model_state_dict']) | ||
self.t = transforms.Resize((96, 160)) | ||
|
||
def forward(self, image): | ||
heatmaps_pred, landmarks_pred, gaze_pred = self.eyenet.forward(self.t(image)) | ||
return gaze_pred | ||
|
||
if __name__ == '__main__': | ||
model = Gaze_estimator() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
from torch import nn | ||
from models.gaze_estimation.models.layers import Conv, Hourglass, Pool, Residual | ||
from models.gaze_estimation.models.losses import HeatmapLoss | ||
from models.gaze_estimation.util.softargmax import softargmax2d | ||
|
||
|
||
class Merge(nn.Module): | ||
def __init__(self, x_dim, y_dim): | ||
super(Merge, self).__init__() | ||
self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False) | ||
|
||
def forward(self, x): | ||
return self.conv(x) | ||
|
||
|
||
class EyeNet(nn.Module): | ||
def __init__(self, nstack, nfeatures, nlandmarks, bn=False, increase=0, **kwargs): | ||
super(EyeNet, self).__init__() | ||
|
||
self.img_w = 160 | ||
self.img_h = 96 | ||
self.nstack = nstack | ||
self.nfeatures = nfeatures | ||
self.nlandmarks = nlandmarks | ||
|
||
self.heatmap_w = self.img_w / 2 | ||
self.heatmap_h = self.img_h / 2 | ||
|
||
self.nstack = nstack | ||
self.pre = nn.Sequential( | ||
Conv(1, 64, 7, 1, bn=True, relu=True), | ||
Residual(64, 128), | ||
Pool(2, 2), | ||
Residual(128, 128), | ||
Residual(128, nfeatures) | ||
) | ||
|
||
self.pre2 = nn.Sequential( | ||
Conv(nfeatures, 64, 7, 2, bn=True, relu=True), | ||
Residual(64, 128), | ||
Pool(2, 2), | ||
Residual(128, 128), | ||
Residual(128, nfeatures) | ||
) | ||
|
||
self.hgs = nn.ModuleList([ | ||
nn.Sequential( | ||
Hourglass(4, nfeatures, bn, increase), | ||
) for i in range(nstack)]) | ||
|
||
self.features = nn.ModuleList([ | ||
nn.Sequential( | ||
Residual(nfeatures, nfeatures), | ||
Conv(nfeatures, nfeatures, 1, bn=True, relu=True) | ||
) for i in range(nstack)]) | ||
|
||
self.outs = nn.ModuleList([Conv(nfeatures, nlandmarks, 1, relu=False, bn=False) for i in range(nstack)]) | ||
self.merge_features = nn.ModuleList([Merge(nfeatures, nfeatures) for i in range(nstack - 1)]) | ||
self.merge_preds = nn.ModuleList([Merge(nlandmarks, nfeatures) for i in range(nstack - 1)]) | ||
|
||
self.gaze_fc1 = nn.Linear(in_features=int(nfeatures * self.img_w * self.img_h / 64 + nlandmarks*2), out_features=256) | ||
self.gaze_fc2 = nn.Linear(in_features=256, out_features=2) | ||
|
||
self.nstack = nstack | ||
self.heatmapLoss = HeatmapLoss() | ||
self.landmarks_loss = nn.MSELoss() | ||
self.gaze_loss = nn.MSELoss() | ||
|
||
def forward(self, imgs): | ||
# imgs of size 1,ih,iw | ||
x = imgs.unsqueeze(1) | ||
x = self.pre(x) | ||
|
||
gaze_x = self.pre2(x) | ||
gaze_x = gaze_x.flatten(start_dim=1) | ||
|
||
combined_hm_preds = [] | ||
for i in torch.arange(self.nstack): | ||
hg = self.hgs[i](x) | ||
feature = self.features[i](hg) | ||
preds = self.outs[i](feature) | ||
combined_hm_preds.append(preds) | ||
if i < self.nstack - 1: | ||
x = x + self.merge_preds[i](preds) + self.merge_features[i](feature) | ||
|
||
heatmaps_out = torch.stack(combined_hm_preds, 1) | ||
|
||
# preds = N x nlandmarks * heatmap_w * heatmap_h | ||
landmarks_out = softargmax2d(preds) # N x nlandmarks x 2 | ||
|
||
# Gaze | ||
gaze = torch.cat((gaze_x, landmarks_out.flatten(start_dim=1)), dim=1) | ||
gaze = self.gaze_fc1(gaze) | ||
gaze = nn.functional.relu(gaze) | ||
gaze = self.gaze_fc2(gaze) | ||
|
||
return heatmaps_out, landmarks_out, gaze | ||
|
||
def calc_loss(self, combined_hm_preds, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze): | ||
combined_loss = [] | ||
for i in range(self.nstack): | ||
combined_loss.append(self.heatmapLoss(combined_hm_preds[:, i, :], heatmaps)) | ||
|
||
heatmap_loss = torch.stack(combined_loss, dim=1) | ||
landmarks_loss = self.landmarks_loss(landmarks_pred, landmarks) | ||
gaze_loss = self.gaze_loss(gaze_pred, gaze) | ||
|
||
return torch.sum(heatmap_loss), landmarks_loss, 1000 * gaze_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from torch import nn | ||
|
||
Pool = nn.MaxPool2d | ||
|
||
|
||
def batchnorm(x): | ||
return nn.BatchNorm2d(x.size()[1])(x) | ||
|
||
|
||
class Conv(nn.Module): | ||
def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True): | ||
super(Conv, self).__init__() | ||
self.inp_dim = inp_dim | ||
self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True) | ||
self.relu = None | ||
self.bn = None | ||
if relu: | ||
self.relu = nn.ReLU() | ||
if bn: | ||
self.bn = nn.BatchNorm2d(out_dim) | ||
|
||
def forward(self, x): | ||
assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim) | ||
x = self.conv(x) | ||
if self.bn is not None: | ||
x = self.bn(x) | ||
if self.relu is not None: | ||
x = self.relu(x) | ||
return x | ||
|
||
|
||
class Residual(nn.Module): | ||
def __init__(self, inp_dim, out_dim): | ||
super(Residual, self).__init__() | ||
self.relu = nn.ReLU() | ||
self.bn1 = nn.BatchNorm2d(inp_dim) | ||
self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False) | ||
self.bn2 = nn.BatchNorm2d(int(out_dim/2)) | ||
self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False) | ||
self.bn3 = nn.BatchNorm2d(int(out_dim/2)) | ||
self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False) | ||
self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False) | ||
if inp_dim == out_dim: | ||
self.need_skip = False | ||
else: | ||
self.need_skip = True | ||
|
||
def forward(self, x): | ||
if self.need_skip: | ||
residual = self.skip_layer(x) | ||
else: | ||
residual = x | ||
out = self.bn1(x) | ||
out = self.relu(out) | ||
out = self.conv1(out) | ||
out = self.bn2(out) | ||
out = self.relu(out) | ||
out = self.conv2(out) | ||
out = self.bn3(out) | ||
out = self.relu(out) | ||
out = self.conv3(out) | ||
out += residual | ||
return out | ||
|
||
|
||
class Hourglass(nn.Module): | ||
def __init__(self, n, f, bn=None, increase=0): | ||
super(Hourglass, self).__init__() | ||
nf = f + increase | ||
self.up1 = Residual(f, f) | ||
# Lower branch | ||
self.pool1 = Pool(2, 2) | ||
self.low1 = Residual(f, nf) | ||
self.n = n | ||
# Recursive hourglass | ||
if self.n > 1: | ||
self.low2 = Hourglass(n-1, nf, bn=bn) | ||
else: | ||
self.low2 = Residual(nf, nf) | ||
self.low3 = Residual(nf, f) | ||
|
||
def forward(self, x): | ||
up1 = self.up1(x) | ||
pool1 = self.pool1(x) | ||
low1 = self.low1(pool1) | ||
low2 = self.low2(low1) | ||
low3 = self.low3(low2) | ||
up2 = nn.functional.interpolate(low3, x.shape[2:], mode='bilinear') | ||
return up1 + up2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
|
||
class HeatmapLoss(torch.nn.Module): | ||
def __init__(self): | ||
super(HeatmapLoss, self).__init__() | ||
|
||
def forward(self, pred, gt): | ||
loss = ((pred - gt)**2) | ||
loss = torch.mean(loss, dim=(1, 2, 3)) | ||
return loss | ||
|
||
|
||
class AngularError(torch.nn.Module): | ||
def __init__(self): | ||
super(AngularError, self).__init__() | ||
|
||
def forward(self, gaze_pred, gaze): | ||
loss = ((gaze_pred - gaze)**2) | ||
loss = torch.mean(loss, dim=(1, 2, 3)) | ||
return loss |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from util.eye_sample import EyeSample | ||
|
||
|
||
class EyePrediction(): | ||
def __init__(self, eye_sample: EyeSample, landmarks, gaze): | ||
self._eye_sample = eye_sample | ||
self._landmarks = landmarks | ||
self._gaze = gaze | ||
|
||
@property | ||
def eye_sample(self): | ||
return self._eye_sample | ||
|
||
@property | ||
def landmarks(self): | ||
return self._landmarks | ||
|
||
@property | ||
def gaze(self): | ||
return self._gaze |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
class EyeSample: | ||
def __init__(self, orig_img, img, is_left, transform_inv, estimated_radius): | ||
self._orig_img = orig_img.copy() | ||
self._img = img.copy() | ||
self._is_left = is_left | ||
self._transform_inv = transform_inv | ||
self._estimated_radius = estimated_radius | ||
@property | ||
def orig_img(self): | ||
return self._orig_img | ||
|
||
@property | ||
def img(self): | ||
return self._img | ||
|
||
@property | ||
def is_left(self): | ||
return self._is_left | ||
|
||
@property | ||
def transform_inv(self): | ||
return self._transform_inv | ||
|
||
@property | ||
def estimated_radius(self): | ||
return self._estimated_radius |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Utility methods for gaze angle and error calculations.""" | ||
import cv2 as cv | ||
import numpy as np | ||
|
||
def pitchyaw_to_vector(pitchyaws): | ||
r"""Convert given yaw (:math:`\theta`) and pitch (:math:`\phi`) angles to unit gaze vectors. | ||
Args: | ||
pitchyaws (:obj:`numpy.array`): yaw and pitch angles :math:`(n\times 2)` in radians. | ||
Returns: | ||
:obj:`numpy.array` of shape :math:`(n\times 3)` with 3D vectors per row. | ||
""" | ||
n = pitchyaws.shape[0] | ||
sin = np.sin(pitchyaws) | ||
cos = np.cos(pitchyaws) | ||
out = np.empty((n, 3)) | ||
out[:, 0] = np.multiply(cos[:, 0], sin[:, 1]) | ||
out[:, 1] = sin[:, 0] | ||
out[:, 2] = np.multiply(cos[:, 0], cos[:, 1]) | ||
return out | ||
|
||
|
||
def vector_to_pitchyaw(vectors): | ||
r"""Convert given gaze vectors to yaw (:math:`\theta`) and pitch (:math:`\phi`) angles. | ||
Args: | ||
vectors (:obj:`numpy.array`): gaze vectors in 3D :math:`(n\times 3)`. | ||
Returns: | ||
:obj:`numpy.array` of shape :math:`(n\times 2)` with values in radians. | ||
""" | ||
n = vectors.shape[0] | ||
out = np.empty((n, 2)) | ||
vectors = np.divide(vectors, np.linalg.norm(vectors, axis=1).reshape(n, 1)) | ||
out[:, 0] = np.arcsin(vectors[:, 1]) # theta | ||
out[:, 1] = np.arctan2(vectors[:, 0], vectors[:, 2]) # phi | ||
return out | ||
|
||
radians_to_degrees = 180.0 / np.pi | ||
|
||
|
||
def angular_error(a, b): | ||
"""Calculate angular error (via cosine similarity).""" | ||
a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a | ||
b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b | ||
|
||
ab = np.sum(np.multiply(a, b), axis=1) | ||
a_norm = np.linalg.norm(a, axis=1) | ||
b_norm = np.linalg.norm(b, axis=1) | ||
|
||
# Avoid zero-values (to avoid NaNs) | ||
a_norm = np.clip(a_norm, a_min=1e-7, a_max=None) | ||
b_norm = np.clip(b_norm, a_min=1e-7, a_max=None) | ||
|
||
similarity = np.divide(ab, np.multiply(a_norm, b_norm)) | ||
|
||
return np.arccos(similarity) * radians_to_degrees | ||
|
||
|
||
def mean_angular_error(a, b): | ||
"""Calculate mean angular error (via cosine similarity).""" | ||
return np.mean(angular_error(a, b)) | ||
|
||
|
||
def draw_gaze(image_in, eye_pos, pitchyaw, length=40.0, thickness=2, color=(0, 0, 255)): | ||
"""Draw gaze angle on given image with a given eye positions.""" | ||
image_out = image_in | ||
if len(image_out.shape) == 2 or image_out.shape[2] == 1: | ||
image_out = cv.cvtColor(image_out, cv.COLOR_GRAY2BGR) | ||
dx = -length * np.sin(pitchyaw[1]) | ||
dy = length * np.sin(pitchyaw[0]) | ||
cv.arrowedLine(image_out, tuple(np.round(eye_pos).astype(np.int32)), | ||
tuple(np.round([eye_pos[0] + dx, eye_pos[1] + dy]).astype(int)), color, | ||
thickness, cv.LINE_AA, tipLength=0.2) | ||
return image_out |
Oops, something went wrong.