Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
hxngiee authored Sep 3, 2023
1 parent 132d1e1 commit 24039e4
Show file tree
Hide file tree
Showing 46 changed files with 6,455 additions and 0 deletions.
399 changes: 399 additions & 0 deletions LICENSE.txt

Large diffs are not rendered by default.

Binary file added data/src/001.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/src/002.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/src/003.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/targ/001.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/targ/002.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/targ/003.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from optimization.image_editor import ImageEditor
from optimization.arguments import get_arguments

if __name__ == "__main__":
    # Script entry point: build the editor from the parsed command-line
    # arguments and run the prompt-driven edit once.
    editor = ImageEditor(get_arguments())
    editor.edit_image_by_prompt()
4 changes: 4 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .models import ArcMarginModel
from .models import ResNet
from .models import IRBlock
from .models import SEBlock
Empty file.
23 changes: 23 additions & 0 deletions models/gaze_estimation/gaze_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn
from models.gaze_estimation.models.eyenet import EyeNet
from torchvision import transforms

class Gaze_estimator(nn.Module):
    """Wrap a pretrained ``EyeNet`` and expose only its gaze prediction.

    The network hyper-parameters (``nstack``, ``nfeatures``, ``nlandmarks``)
    and the weights (``model_state_dict``) are all read from the checkpoint.

    Args:
        checkpoint_path: Location of the serialized checkpoint. Defaults to
            the path that was previously hard-coded.
        device: Device (string or ``torch.device``) the checkpoint is mapped
            to and the network is placed on. Defaults to ``'cpu'``.
    """

    def __init__(self, checkpoint_path='checkpoints/GazeEstimator.pt', device='cpu'):
        super().__init__()
        self.device = torch.device(device)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should ever be passed here.
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.nstack = self.checkpoint['nstack']
        self.nfeatures = self.checkpoint['nfeatures']
        self.nlandmarks = self.checkpoint['nlandmarks']
        self.eyenet = EyeNet(
            nstack=self.nstack,
            nfeatures=self.nfeatures,
            nlandmarks=self.nlandmarks,
        ).to(self.device)
        self.eyenet.load_state_dict(self.checkpoint['model_state_dict'])
        # 96 x 160 (height x width) matches the img_h / img_w fixed in EyeNet.
        self.t = transforms.Resize((96, 160))

    def forward(self, image):
        """Resize ``image`` to 96x160 and return EyeNet's gaze output.

        Args:
            image: Eye image batch. Assumes a ``(N, H, W)`` single-channel
                tensor, because EyeNet adds the channel axis itself -- TODO
                confirm against callers.

        Returns:
            The third element of EyeNet's output (the gaze prediction); the
            heatmaps and landmarks are discarded.
        """
        # Call the module itself rather than ``.forward`` so that any
        # registered forward hooks are honoured.
        _heatmaps, _landmarks, gaze_pred = self.eyenet(self.t(image))
        return gaze_pred

if __name__ == '__main__':
    # Manual smoke check: constructing the estimator loads the checkpoint
    # and builds the network; nothing is run on it afterwards.
    model = Gaze_estimator()
Empty file.
109 changes: 109 additions & 0 deletions models/gaze_estimation/models/eyenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from torch import nn
from models.gaze_estimation.models.layers import Conv, Hourglass, Pool, Residual
from models.gaze_estimation.models.losses import HeatmapLoss
from models.gaze_estimation.util.softargmax import softargmax2d


class Merge(nn.Module):
    """1x1 convolution (no batch norm, no ReLU) mapping x_dim to y_dim channels."""

    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.conv = Conv(x_dim, y_dim, kernel_size=1, bn=False, relu=False)

    def forward(self, x):
        """Return ``x`` projected through the 1x1 convolution."""
        projected = self.conv(x)
        return projected


class EyeNet(nn.Module):
    """Stacked-hourglass network producing heatmaps, landmarks and a gaze vector.

    The input resolution is fixed at 160x96 (width x height). A shared stem
    (``pre``) feeds both the hourglass stacks, which predict one heatmap per
    landmark, and a second stem (``pre2``) whose flattened output is
    concatenated with the landmark coordinates and regressed to two gaze
    values.

    Args:
        nstack: Number of hourglass stacks.
        nfeatures: Channel width used throughout the stacks.
        nlandmarks: Number of landmark heatmaps predicted per stack.
        bn: Forwarded to every ``Hourglass``.
        increase: Forwarded to every ``Hourglass``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, nstack, nfeatures, nlandmarks, bn=False, increase=0, **kwargs):
        super(EyeNet, self).__init__()

        self.img_w = 160
        self.img_h = 96
        self.nstack = nstack
        self.nfeatures = nfeatures
        self.nlandmarks = nlandmarks

        # Kept as true division (float values) exactly as before; ``pre``
        # halves the resolution once with its single Pool(2, 2).
        self.heatmap_w = self.img_w / 2
        self.heatmap_h = self.img_h / 2

        # Shared stem: 1 input channel -> nfeatures at half resolution.
        self.pre = nn.Sequential(
            Conv(1, 64, 7, 1, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, nfeatures)
        )

        # Gaze stem: a stride-2 conv plus a 2x2 pool shrink each side by a
        # further factor of 4 (8 in total with ``pre``).
        self.pre2 = nn.Sequential(
            Conv(nfeatures, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, nfeatures)
        )

        self.hgs = nn.ModuleList([
            nn.Sequential(
                Hourglass(4, nfeatures, bn, increase),
            ) for _ in range(nstack)])

        self.features = nn.ModuleList([
            nn.Sequential(
                Residual(nfeatures, nfeatures),
                Conv(nfeatures, nfeatures, 1, bn=True, relu=True)
            ) for _ in range(nstack)])

        self.outs = nn.ModuleList(
            [Conv(nfeatures, nlandmarks, 1, relu=False, bn=False) for _ in range(nstack)])
        # One merge pair per transition between consecutive stacks.
        self.merge_features = nn.ModuleList(
            [Merge(nfeatures, nfeatures) for _ in range(nstack - 1)])
        self.merge_preds = nn.ModuleList(
            [Merge(nlandmarks, nfeatures) for _ in range(nstack - 1)])

        # Flattened ``pre2`` output (nfeatures * W/8 * H/8, hence the /64)
        # plus an (x, y) pair per landmark.
        self.gaze_fc1 = nn.Linear(
            in_features=int(nfeatures * self.img_w * self.img_h / 64 + nlandmarks*2),
            out_features=256)
        self.gaze_fc2 = nn.Linear(in_features=256, out_features=2)

        self.heatmapLoss = HeatmapLoss()
        self.landmarks_loss = nn.MSELoss()
        self.gaze_loss = nn.MSELoss()

    def forward(self, imgs):
        """Run the network on a batch of eye images.

        Args:
            imgs: Batch without a channel axis (one is inserted at dim 1).
                Assumes shape ``(N, img_h, img_w)`` as the fully connected
                layer size is derived from those constants -- TODO confirm.

        Returns:
            Tuple ``(heatmaps_out, landmarks_out, gaze)``: the per-stack
            heatmaps stacked along dim 1, the soft-argmax of the last stack's
            heatmaps, and the output of the two-layer gaze regressor.
        """
        x = imgs.unsqueeze(1)
        x = self.pre(x)

        gaze_x = self.pre2(x).flatten(start_dim=1)

        combined_hm_preds = []
        last_stack = self.nstack - 1
        # Plain ``range`` keeps the indices as Python ints (the previous
        # ``torch.arange`` yielded 0-d tensors used for indexing/comparison).
        for i in range(self.nstack):
            hg = self.hgs[i](x)
            feature = self.features[i](hg)
            preds = self.outs[i](feature)
            combined_hm_preds.append(preds)
            if i < last_stack:
                # Feed this stack's prediction and features into the next one.
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)

        heatmaps_out = torch.stack(combined_hm_preds, 1)

        # Landmarks come from the final stack only; per the original note
        # softargmax2d yields N x nlandmarks x 2.
        landmarks_out = softargmax2d(combined_hm_preds[-1])

        # Gaze regression on image features plus landmark coordinates.
        gaze = torch.cat((gaze_x, landmarks_out.flatten(start_dim=1)), dim=1)
        gaze = self.gaze_fc1(gaze)
        gaze = nn.functional.relu(gaze)
        gaze = self.gaze_fc2(gaze)

        return heatmaps_out, landmarks_out, gaze

    def calc_loss(self, combined_hm_preds, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze):
        """Compute the three training losses.

        Args:
            combined_hm_preds: Stacked heatmap predictions, indexed by stack
                along dim 1.
            heatmaps: Target heatmaps, compared against every stack.
            landmarks_pred: Predicted landmarks.
            landmarks: Target landmarks.
            gaze_pred: Predicted gaze.
            gaze: Target gaze.

        Returns:
            Tuple of (sum of per-sample, per-stack heatmap losses, landmark
            MSE, gaze MSE scaled by 1000).
        """
        per_stack = [
            self.heatmapLoss(combined_hm_preds[:, i, :], heatmaps)
            for i in range(self.nstack)
        ]
        heatmap_loss = torch.stack(per_stack, dim=1)
        landmarks_loss = self.landmarks_loss(landmarks_pred, landmarks)
        gaze_loss = self.gaze_loss(gaze_pred, gaze)

        return torch.sum(heatmap_loss), landmarks_loss, 1000 * gaze_loss
89 changes: 89 additions & 0 deletions models/gaze_estimation/models/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from torch import nn

Pool = nn.MaxPool2d


def batchnorm(x):
    """Normalize ``x`` with a freshly constructed ``BatchNorm2d``.

    A new layer is built on every call, so nothing is learned or remembered
    between calls: the layer has its default affine parameters and is in
    training mode, which means the current batch statistics are used.

    Args:
        x: Tensor whose dim 1 is the channel axis.

    Returns:
        The batch-normalized tensor, same shape as ``x``.
    """
    # Build the layer on the input's device; a CPU-only layer would fail
    # with a device mismatch for non-CPU inputs.
    layer = nn.BatchNorm2d(x.size()[1]).to(x.device)
    return layer(x)


class Conv(nn.Module):
    """2-D convolution optionally followed by batch norm and then ReLU.

    Padding is ``(kernel_size - 1) // 2``, so odd kernels at stride 1 keep
    the spatial size. ``self.bn`` / ``self.relu`` are ``None`` when disabled.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
        super().__init__()
        self.inp_dim = inp_dim
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=pad, bias=True)
        # ReLU is registered before batch norm (registration order only);
        # forward() still applies batch norm first.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply conv -> (batch norm) -> (ReLU); asserts the channel count."""
        channels = x.size()[1]
        assert channels == self.inp_dim, "{} {}".format(channels, self.inp_dim)
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out


class Residual(nn.Module):
    """Pre-activation bottleneck block: three (BN, ReLU, conv) stages plus a skip.

    The bottleneck width is ``int(out_dim / 2)``. The 1x1 ``skip_layer`` is
    always constructed, but only applied when the channel count changes.
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        mid_dim = int(out_dim/2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = Conv(mid_dim, mid_dim, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        # The identity shortcut is only valid when channels are unchanged.
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        """Return the bottleneck output added to the (possibly projected) input."""
        residual = self.skip_layer(x) if self.need_skip else x
        out = x
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        for norm, conv in stages:
            out = conv(self.relu(norm(out)))
        out += residual
        return out


class Hourglass(nn.Module):
    """Recursive hourglass module of depth ``n`` working on ``f`` channels.

    Each level keeps a full-resolution branch (``up1``) and a pooled branch
    (``low1`` -> ``low2`` -> ``low3``) that is upsampled back and added.
    ``bn`` is only passed on to the nested hourglass; ``increase`` widens
    the pooled branch at this level and is not passed on to nested levels.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Pooled (lower-resolution) branch.
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recurse until depth 1, where a plain residual block closes the chain.
        self.low2 = Hourglass(n-1, nf, bn=bn) if self.n > 1 else Residual(nf, nf)
        self.low3 = Residual(nf, f)

    def forward(self, x):
        """Return the sum of the full-resolution and upsampled pooled branches."""
        skip = self.up1(x)
        pooled = self.low3(self.low2(self.low1(self.pool1(x))))
        # Resize to the input's spatial size rather than by a fixed factor.
        restored = nn.functional.interpolate(pooled, x.shape[2:], mode='bilinear')
        return skip + restored
21 changes: 21 additions & 0 deletions models/gaze_estimation/models/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch


class HeatmapLoss(torch.nn.Module):
    """Per-sample mean squared error between predicted and target heatmaps."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt):
        """Return the squared error averaged over dims 1, 2 and 3.

        Dim 0 is kept, so the result holds one value per entry along dim 0.
        """
        diff = pred - gt
        return torch.mean(diff ** 2, dim=(1, 2, 3))


class AngularError(torch.nn.Module):
    """Per-sample mean squared difference between predicted and target gaze.

    Despite the name, no angle is computed: this is a squared-error loss.
    """

    def __init__(self):
        super(AngularError, self).__init__()

    def forward(self, gaze_pred, gaze):
        """Return the squared error averaged over the non-batch dimensions.

        Args:
            gaze_pred: Predicted gaze tensor.
            gaze: Target gaze tensor, broadcastable against ``gaze_pred``.

        Returns:
            For inputs of rank 4 or more, the mean over dims 1, 2 and 3
            (unchanged behaviour). For lower ranks, the mean over whichever
            of those dims exist; rank 0/1 inputs are returned un-reduced.
        """
        loss = (gaze_pred - gaze) ** 2
        # The reduction used to be hard-coded to dims (1, 2, 3), which raises
        # on (N, 2) gaze tensors; only reduce over dims that are present.
        reduce_dims = tuple(range(1, min(loss.dim(), 4)))
        if not reduce_dims:
            return loss
        return torch.mean(loss, dim=reduce_dims)
Empty file.
20 changes: 20 additions & 0 deletions models/gaze_estimation/util/eye_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from util.eye_sample import EyeSample


class EyePrediction:
    """Read-only bundle of an eye sample with its predicted landmarks and gaze."""

    def __init__(self, eye_sample: EyeSample, landmarks, gaze):
        # Stored by reference; no copies are made.
        self._eye_sample, self._landmarks, self._gaze = eye_sample, landmarks, gaze

    @property
    def eye_sample(self):
        """The eye sample this prediction was made for."""
        return self._eye_sample

    @property
    def landmarks(self):
        """The landmarks given at construction."""
        return self._landmarks

    @property
    def gaze(self):
        """The gaze value given at construction."""
        return self._gaze
27 changes: 27 additions & 0 deletions models/gaze_estimation/util/eye_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

class EyeSample:
    """Read-only record describing one eye crop.

    Both images are copied (via their ``.copy()`` method) at construction;
    the other values are stored as given.
    """

    def __init__(self, orig_img, img, is_left, transform_inv, estimated_radius):
        # Snapshot the images so later changes by the caller do not leak in.
        self._orig_img, self._img = orig_img.copy(), img.copy()
        self._is_left = is_left
        self._transform_inv = transform_inv
        self._estimated_radius = estimated_radius

    @property
    def orig_img(self):
        """Copy of the ``orig_img`` passed to the constructor."""
        return self._orig_img

    @property
    def img(self):
        """Copy of the ``img`` passed to the constructor."""
        return self._img

    @property
    def is_left(self):
        """The ``is_left`` flag passed to the constructor."""
        return self._is_left

    @property
    def transform_inv(self):
        """The ``transform_inv`` value passed to the constructor."""
        return self._transform_inv

    @property
    def estimated_radius(self):
        """The ``estimated_radius`` value passed to the constructor."""
        return self._estimated_radius
76 changes: 76 additions & 0 deletions models/gaze_estimation/util/gaze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Utility methods for gaze angle and error calculations."""
import cv2 as cv
import numpy as np

def pitchyaw_to_vector(pitchyaws):
    r"""Convert (pitch, yaw) angle pairs to unit gaze vectors.

    Args:
        pitchyaws (:obj:`numpy.array`): angles :math:`(n\times 2)` in radians,
            pitch in column 0 and yaw in column 1.
    Returns:
        :obj:`numpy.array` of shape :math:`(n\times 3)` with one vector
        ``(cos(pitch) sin(yaw), sin(pitch), cos(pitch) cos(yaw))`` per row.
    """
    rows = pitchyaws.shape[0]
    pitch, yaw = pitchyaws[:, 0], pitchyaws[:, 1]
    cos_pitch = np.cos(pitch)

    vectors = np.empty((rows, 3))
    vectors[:, 0] = cos_pitch * np.sin(yaw)
    vectors[:, 1] = np.sin(pitch)
    vectors[:, 2] = cos_pitch * np.cos(yaw)
    return vectors


def vector_to_pitchyaw(vectors):
    r"""Convert 3D gaze vectors to (pitch, yaw) angle pairs.

    Each vector is normalized to unit length first.

    Args:
        vectors (:obj:`numpy.array`): gaze vectors in 3D :math:`(n\times 3)`.
    Returns:
        :obj:`numpy.array` of shape :math:`(n\times 2)` in radians, holding
        ``arcsin(y)`` in column 0 and ``arctan2(x, z)`` in column 1.
    """
    rows = vectors.shape[0]
    angles = np.empty((rows, 2))
    lengths = np.linalg.norm(vectors, axis=1).reshape(rows, 1)
    unit = vectors / lengths
    x, y, z = unit[:, 0], unit[:, 1], unit[:, 2]
    angles[:, 0] = np.arcsin(y)  # theta
    angles[:, 1] = np.arctan2(x, z)  # phi
    return angles

radians_to_degrees = 180.0 / np.pi


def angular_error(a, b):
    """Calculate the angular error in degrees between paired gaze directions.

    The angle is obtained from the cosine similarity of each row pair.

    Args:
        a: Array with one direction per row. Rows with exactly two columns
            are treated as angle pairs and converted to 3D vectors with
            ``pitchyaw_to_vector``; any other width is used as given.
        b: Same as ``a``.

    Returns:
        1-D array of angles in degrees, one per row.
    """
    a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a
    b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b

    ab = np.sum(np.multiply(a, b), axis=1)

    # Floor the norms so zero-length vectors do not cause a division by zero.
    a_norm = np.clip(np.linalg.norm(a, axis=1), a_min=1e-7, a_max=None)
    b_norm = np.clip(np.linalg.norm(b, axis=1), a_min=1e-7, a_max=None)

    similarity = np.divide(ab, np.multiply(a_norm, b_norm))
    # Rounding can push the cosine marginally outside [-1, 1] (for example
    # for identical vectors), which would make arccos return NaN.
    similarity = np.clip(similarity, -1.0, 1.0)

    return np.degrees(np.arccos(similarity))


def mean_angular_error(a, b):
    """Return the average of ``angular_error(a, b)`` over all rows."""
    per_row = angular_error(a, b)
    return np.mean(per_row)


def draw_gaze(image_in, eye_pos, pitchyaw, length=40.0, thickness=2, color=(0, 0, 255)):
    """Draw the gaze direction as an arrow starting at ``eye_pos``.

    A 2-D or single-channel image is first converted with ``COLOR_GRAY2BGR``
    and the arrow is drawn on the converted image; otherwise the arrow is
    drawn directly on ``image_in``. The arrow offset is
    ``(-length * sin(pitchyaw[1]), length * sin(pitchyaw[0]))``.

    Returns:
        The image that was drawn on.
    """
    canvas = image_in
    is_grayscale = canvas.ndim == 2 or canvas.shape[2] == 1
    if is_grayscale:
        canvas = cv.cvtColor(canvas, cv.COLOR_GRAY2BGR)

    offset_x = -length * np.sin(pitchyaw[1])
    offset_y = length * np.sin(pitchyaw[0])

    # Pixel coordinates must be integers for the drawing call.
    start = tuple(np.round(eye_pos).astype(np.int32))
    tip = tuple(np.round([eye_pos[0] + offset_x, eye_pos[1] + offset_y]).astype(int))
    cv.arrowedLine(canvas, start, tip, color, thickness, cv.LINE_AA, tipLength=0.2)
    return canvas
Loading

0 comments on commit 24039e4

Please sign in to comment.