-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1cbf2cd
Showing
159 changed files
with
523,372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# TPNMS | ||
|
||
**Code and models for AAAI 2021 paper: *Temporal Pyramid Network for Pedestrian Trajectory Prediction with Multi-Supervision*** | ||
|
||
### Environment | ||
|
||
- Python 3.8 | ||
- pytorch 1.11.0 | ||
- cuda 11.3 | ||
- Ubuntu 20.04 | ||
- RTX 3090 | ||
- Please refer to the "requirements.txt" file for more details. | ||
|
||
### Usage | ||
To test the model, run: | ||
```bash | ||
scripts/evaluate_model.py | ||
``` | ||
To train the model, run: | ||
```bash | ||
scripts/train_TPN_P.py | ||
``` | ||
|
||
|
||
|
||
### Citation | ||
If you find this work useful in your research, please consider citing: | ||
``` | ||
@inproceedings{liang2021temporal, | ||
title={Temporal Pyramid Network for Pedestrian Trajectory Prediction with Multi-Supervision}, | ||
author={Liang, Rongqin and Li, Yuanman and Li, Xia and Tang, Yi and Zhou, Jiantao and Zou, Wenbin}, | ||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, | ||
volume={35}, | ||
number={3}, | ||
pages={2029--2037}, | ||
year={2021} | ||
} | ||
``` | ||
|
||
### Contact | ||
|
||
If you encounter any issue when running the code, please feel free to reach us either by creating a new issue in the github or by emailing | ||
|
||
+ [email protected] |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .trajectories import seq_collate_TPN, TrajectoryDatasetTPN |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from torch.utils.data import DataLoader | ||
from TPN.data.trajectories import TrajectoryDatasetTPN, seq_collate_TPN | ||
|
||
|
||
def data_loader_TPN(args, path): | ||
dset = TrajectoryDatasetTPN( | ||
path, | ||
obs_len=args.obs_len, | ||
pred_len=args.pred_len, | ||
skip=args.skip, | ||
delim=args.delim) | ||
|
||
loader = DataLoader( | ||
dset, | ||
batch_size=args.batch_size, | ||
shuffle=True, | ||
num_workers=args.loader_num_workers, | ||
collate_fn=seq_collate_TPN) | ||
return dset, loader | ||
|
||
|
||
def data_loader_TPN_test(args, path): | ||
dset = TrajectoryDatasetTPN( | ||
path, | ||
obs_len=args.obs_len, | ||
pred_len=args.pred_len, | ||
skip=args.skip, | ||
delim=args.delim) | ||
|
||
loader = DataLoader( | ||
dset, | ||
batch_size=args.batch_size, | ||
shuffle=False, | ||
num_workers=args.loader_num_workers, | ||
collate_fn=seq_collate_TPN) | ||
return dset, loader |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import torch | ||
import random | ||
|
||
def bce_loss(input, target): | ||
""" | ||
Numerically stable version of the binary cross-entropy loss function. | ||
As per https://github.com/pytorch/pytorch/issues/751 | ||
See the TensorFlow docs for a derivation of this formula: | ||
https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits | ||
Input: | ||
- input: PyTorch Tensor of shape (N, ) giving scores. | ||
- target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets. | ||
Output: | ||
- A PyTorch Tensor containing the mean BCE loss over the minibatch of | ||
input data. | ||
""" | ||
neg_abs = -input.abs() | ||
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() | ||
return loss.mean() | ||
|
||
|
||
def gan_g_loss(scores_fake): | ||
""" | ||
Input: | ||
- scores_fake: Tensor of shape (N,) containing scores for fake samples | ||
Output: | ||
- loss: Tensor of shape (,) giving GAN generator loss | ||
""" | ||
y_fake = torch.ones_like(scores_fake) * random.uniform(0.7, 1.2) | ||
return bce_loss(scores_fake, y_fake) | ||
|
||
|
||
def gan_d_loss(scores_real, scores_fake): | ||
""" | ||
Input: | ||
- scores_real: Tensor of shape (N,) giving scores for real samples | ||
- scores_fake: Tensor of shape (N,) giving scores for fake samples | ||
Output: | ||
- loss: Tensor of shape (,) giving GAN discriminator loss | ||
""" | ||
y_real = torch.ones_like(scores_real) * random.uniform(0.7, 1.2) | ||
y_fake = torch.zeros_like(scores_fake) * random.uniform(0, 0.3) | ||
loss_real = bce_loss(scores_real, y_real) | ||
loss_fake = bce_loss(scores_fake, y_fake) | ||
return loss_real + loss_fake | ||
|
||
|
||
def l2_loss(pred_traj, pred_traj_gt, loss_mask=None, random=0, mode='average'): | ||
""" | ||
Input: | ||
- pred_traj: Tensor of shape (seq_len, batch, 2). Predicted trajectory. | ||
- pred_traj_gt: Tensor of shape (seq_len, batch, 2). Groud truth | ||
predictions. | ||
- loss_mask: Tensor of shape (batch, seq_len) | ||
- mode: Can be one of sum, average, raw | ||
Output: | ||
- loss: l2 loss depending on mode | ||
""" | ||
seq_len, batch, _ = pred_traj.size() | ||
# loss = (loss_mask.unsqueeze(dim=2) * | ||
# (pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2))**2) | ||
loss = ((pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2))**2) # 1.20 | ||
if mode == 'sum': | ||
return torch.sum(loss) | ||
elif mode == 'average': | ||
return torch.sum(loss) / torch.numel(loss_mask.data) | ||
elif mode == 'raw': | ||
return loss.sum(dim=2).sum(dim=1) | ||
|
||
|
||
def displacement_error(pred_traj, pred_traj_gt, consider_ped=None, mode='sum'): | ||
""" | ||
Input: | ||
- pred_traj: Tensor of shape (seq_len, batch, 2). Predicted trajectory. | ||
- pred_traj_gt: Tensor of shape (seq_len, batch, 2). Ground truth | ||
predictions. | ||
- consider_ped: Tensor of shape (batch) | ||
- mode: Can be one of sum, raw | ||
Output: | ||
- loss: gives the eculidian displacement error | ||
""" | ||
seq_len, _, _ = pred_traj.size() | ||
loss = pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2) | ||
loss = loss**2 | ||
if consider_ped is not None: | ||
loss = torch.sqrt(loss.sum(dim=2)).sum(dim=1) * consider_ped | ||
else: | ||
loss = torch.sqrt(loss.sum(dim=2)).sum(dim=1) | ||
if mode == 'sum': | ||
return torch.sum(loss) | ||
elif mode == 'raw': | ||
return loss | ||
|
||
|
||
def final_displacement_error( | ||
pred_pos, pred_pos_gt, consider_ped=None, mode='sum' | ||
): | ||
""" | ||
Input: | ||
- pred_pos: Tensor of shape (batch, 2). Predicted last pos. | ||
- pred_pos_gt: Tensor of shape (seq_len, batch, 2). Groud truth | ||
last pos | ||
- consider_ped: Tensor of shape (batch) | ||
Output: | ||
- loss: gives the eculidian displacement error | ||
""" | ||
loss = pred_pos_gt - pred_pos | ||
loss = loss**2 | ||
if consider_ped is not None: | ||
loss = torch.sqrt(loss.sum(dim=1)) * consider_ped | ||
else: | ||
loss = torch.sqrt(loss.sum(dim=1)) | ||
if mode == 'raw': | ||
return loss | ||
else: | ||
return torch.sum(loss) |
Oops, something went wrong.