checkpoints/
experiments_/
wandb
*.ipynb
*.sh
data/csl-daily/*.zip
data/csl-daily/*keypoints*
data/phoenix-2014*/*.zip
data/phoenix-2014*/*keypoints*
experiments/outputs
experiments/configs
experiments/scripts
pretrained_models/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
# TwoStreamSLT
A TwoStream network for sign language recognition and translation, including official implementations for
* [A Simple Multi-modality Transfer Learning Baseline for Sign Language Translation, CVPR 2022](https://arxiv.org/abs/2203.04287)
* [Two-Stream Network for Sign Language Recognition and Translation, NeurIPS 2022](https://arxiv.org/abs/2211.01367)

## Introduction
Sign Language Translation (SLT) and Sign Language Recognition (SLR) suffer from data scarcity. To mitigate this problem, we first propose [a simple multi-modality transfer learning baseline for SLT](https://arxiv.org/abs/2203.04287), which leverages extra supervision from large-scale general-domain datasets by progressively pretraining each module from general-domain to within-domain data, and finally conducting multi-modal joint training. This simple yet effective baseline achieves strong translation performance and improves significantly over previous works.

<img src="images/baseline.png" width="800">

We further propose a [TwoStream network for SLR and SLT](https://arxiv.org/abs/2211.01367), which incorporates domain knowledge of human keypoints into the visual encoder. The TwoStream network achieves state-of-the-art performance across SLR and SLT benchmarks (18.8 WER on Phoenix-2014, 19.3 WER on Phoenix-2014T, 29.0 BLEU4 on Phoenix-2014T, and 25.8 BLEU4 on CSL-Daily).

<img src="images/TwoStream_illustration.png" width="750">

## Performance
Pre-trained models can be found in [GoogleDrive]() or [BaiduCloud]().

**SingleStream-SLT** (the simple multi-modality transfer learning baseline for SLT)

| Dataset | ROUGE | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | Model | Training |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Phoenix-2014T | 53.08 | 54.48 | 41.93 | 33.97 | 28.57 | [ckpt]() | [config](experiments/configs/SingleStream/phoenix-2014t_s2t.yaml) |
| CSL-Daily | 53.35 | 53.53 | 40.68 | 31.04 | 24.09 | [ckpt]() | [config](experiments/configs/SingleStream/csl-daily_s2t.yaml) |

**TwoStream-SLR**

| Dataset | WER | Model | Training |
| :---: | :---: | :---: | :---: |
| Phoenix-2014 | 18.8 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014_s2g.yaml) |
| Phoenix-2014T | 19.3 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014t_s2g.yaml) |
| CSL-Daily | 25.3 | [ckpt]() | [config](experiments/configs/TwoStream/csl-daily_s2g.yaml) |

**TwoStream-SLT**

| Dataset | ROUGE | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | Model | Training |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Phoenix-2014T | 53.48 | 54.90 | 42.43 | 34.46 | 28.95 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014t_s2t_ensemble.yaml) |
| CSL-Daily | 55.72 | 55.44 | 42.59 | 32.87 | 25.79 | [ckpt]() | [config](experiments/configs/TwoStream/csl-daily_s2t_ensemble.yaml) |

## Usage
### Prerequisites
Create a conda environment and install the dependencies:
```
conda env create -f environment.yml
conda activate slt
```

### Download
You can run [download.sh](download.sh), which automatically downloads the datasets (except CSL-Daily, which requires submitting an agreement first), pretrained models, and keypoints, and places them in the corresponding locations. Alternatively, you can download these files separately as follows.

**Datasets**

Download the datasets from their websites and place them under the corresponding directories in data/:
* [Phoenix-2014](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX/)
* [Phoenix-2014T](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX-2014-T/)
* [CSL-Daily](http://home.ustc.edu.cn/~zhouh156/dataset/csl-daily/)

Then run [preprocess/preprocess_video.sh](preprocess/preprocess_video.sh) to extract the downloaded videos.

**Pretrained Models**

We provide pretrained models [here](). Download this directory and place it as *pretrained_models*. The required pretrained models include:
* *s3ds_actioncls_ckpt*: S3D backbone pretrained on Kinetics-400 (from [https://github.com/kylemin/S3D](https://github.com/kylemin/S3D); thanks for their implementation!).
* *s3ds_glosscls_ckpt*: S3D backbone pretrained on Kinetics-400 and WLASL.
* *mbart_de* / *mbart_zh*: pretrained language models used to initialize the translation network for German and Chinese, with weights from [mbart-cc-25](https://huggingface.co/facebook/mbart-large-cc25). We prune mBART's original word embedding, keeping only German or Chinese tokens, to avoid running out of GPU memory, and we compute each gloss embedding by averaging the mBART-pretrained embeddings of all of its sub-tokens (see [utils/prune_embedding.ipynb](utils/prune_embedding.ipynb) and the sketch below).
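
The snippet below is a minimal sketch of the gloss-embedding step described above: it averages the pretrained mBART sub-token embeddings of each gloss. The reference implementation is [utils/prune_embedding.ipynb](utils/prune_embedding.ipynb); the gloss list and output path here are made up for illustration.
```
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, 1024)

glosses = ["MORGEN", "REGEN", "SONNE"]  # hypothetical gloss vocabulary
gloss2embedding = {}
with torch.no_grad():
    for gloss in glosses:
        # Split the gloss into mBART sub-tokens (no BOS/EOS or language tags).
        token_ids = tokenizer(gloss, add_special_tokens=False)["input_ids"]
        # Average the pretrained embeddings of all sub-tokens of this gloss.
        gloss2embedding[gloss] = embedding_matrix[token_ids].mean(dim=0)

torch.save(gloss2embedding, "gloss_embeddings.pkl")  # illustrative output path
```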

**Keypoints** (only needed for TwoStream)

We provide [human keypoints]() for all three datasets, pre-extracted by HRNet. Please download them and place them under *data/phoenix-2014t*, *data/phoenix-2014*, or *data/csl-daily*.

### Training and Evaluation

* For the **SingleStream-SLT baseline**, please see [SingleStream-SLT.md](docs/SingleStream-SLT.md).
* For **TwoStream-SLR**, please see [TwoStream-SLR.md](docs/TwoStream-SLR.md).
* For **TwoStream-SLT** (built on top of TwoStream-SLR), please see [TwoStream-SLT.md](docs/TwoStream-SLT.md).

## Citations
Please cite the following works if you find this repo helpful:
```
@inproceedings{chen2022twostream,
  title={Two-Stream Network for Sign Language Recognition and Translation},
  author={Yutong Chen and Ronglai Zuo and Fangyun Wei and Yu Wu and Shujie Liu and Brian Mak},
  booktitle={Advances in Neural Information Processing Systems},
  editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year={2022},
  url={https://openreview.net/forum?id=hSxK-4KGLbI}
}

@inproceedings{Chen_2022_CVPR,
  author={Chen, Yutong and Wei, Fangyun and Sun, Xiao and Wu, Zhirong and Lin, Stephen},
  title={A Simple Multi-Modality Transfer Learning Baseline for Sign Language Translation},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month={June},
  year={2022},
  pages={5120-5130}
}
```

# Dataloader utilities: batch collation and DataLoader construction for the
# S2G / S2T / G2T / S2T_Ensemble tasks.
from dataset.VideoLoader import load_batch_video
from dataset.FeatureLoader import load_batch_feature
from dataset.Dataset import build_dataset
import torch


def collate_fn_(inputs, data_cfg, task, is_train,
                text_tokenizer=None, gloss_tokenizer=None, name2keypoint=None):
    # Collate a list of per-sample dicts into a single batch dict.
    outputs = {
        'name': [i['name'] for i in inputs],
        'gloss': [i.get('gloss', '') for i in inputs],
        'text': [i.get('text', '') for i in inputs],
        'num_frames': [i['num_frames'] for i in inputs]}
    if task == 'S2G':
        # Sign-to-gloss recognition: tokenize glosses and load raw videos (and keypoints).
        outputs['recognition_inputs'] = gloss_tokenizer(outputs['gloss'])

        sgn_videos, sgn_keypoints, sgn_lengths = load_batch_video(
            zip_file=data_cfg['zip_file'],
            names=outputs['name'],
            num_frames=outputs['num_frames'],
            transform_cfg=data_cfg['transform_cfg'],
            dataset_name=data_cfg['dataset_name'],
            pad_length=data_cfg.get('pad_length', 'pad_to_max'),
            pad=data_cfg.get('pad', 'replicate'),
            is_train=is_train,
            name2keypoint=name2keypoint,
        )
        outputs['recognition_inputs']['sgn_videos'] = sgn_videos
        outputs['recognition_inputs']['sgn_keypoints'] = sgn_keypoints
        outputs['recognition_inputs']['sgn_lengths'] = sgn_lengths

    if task in ['S2T', 'G2T', 'S2T_Ensemble']:
        tokenized_text = text_tokenizer(input_str=outputs['text'])
        outputs['translation_inputs'] = {**tokenized_text}
        if task == 'S2T':
            # Sign-to-text: recognition inputs are gloss labels plus pre-extracted visual features.
            outputs['recognition_inputs'] = gloss_tokenizer(outputs['gloss'])
            outputs['translation_inputs']['gloss_ids'] = outputs['recognition_inputs']['gloss_labels']
            outputs['translation_inputs']['gloss_lengths'] = outputs['recognition_inputs']['gls_lengths']
            for feature_name in ['sgn_features', 'head_rgb_input', 'head_keypoint_input']:
                if feature_name in inputs[0]:
                    # Features get a tiny offset (1e-8) before being padded into a batch.
                    outputs['recognition_inputs'][feature_name], sgn_mask, sgn_lengths = \
                        load_batch_feature(features=[i[feature_name] + 1.0e-8 for i in inputs])
                    outputs['recognition_inputs']['sgn_mask'] = sgn_mask
                    outputs['recognition_inputs']['sgn_lengths'] = sgn_lengths
        elif task == 'G2T':
            # Gloss-to-text: the tokenized gloss sequence is the translation input.
            tokenized_gloss = gloss_tokenizer(batch_gls_seq=outputs['gloss'])
            outputs['translation_inputs']['input_ids'] = tokenized_gloss['input_ids']
            outputs['translation_inputs']['attention_mask'] = tokenized_gloss['attention_mask']
        elif task == 'S2T_Ensemble':
            # Ensemble decoding: each sample carries a list of visual embeddings, one per stream/model.
            outputs['translation_inputs']['inputs_embeds_list'] = []
            outputs['translation_inputs']['attention_mask_list'] = []
            for ii in range(len(inputs[0]['inputs_embeds_list'])):
                inputs_embeds, mask_, _ = load_batch_feature(
                    features=[i['inputs_embeds_list'][ii] for i in inputs])
                outputs['translation_inputs']['inputs_embeds_list'].append(inputs_embeds)
                outputs['translation_inputs']['attention_mask_list'].append(mask_)
    return outputs


def build_dataloader(cfg, split,
                     text_tokenizer=None, gloss_tokenizer=None,
                     mode='auto', val_distributed=False):
    dataset = build_dataset(cfg['data'], split)
    mode = split if mode == 'auto' else mode
    if mode == 'train':
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            shuffle=cfg['training']['shuffle'] and split == 'train'
        )
    else:
        if val_distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=lambda x: collate_fn_(
            inputs=x,
            task=cfg['task'],
            data_cfg=cfg['data'],
            is_train=(mode == 'train'),
            text_tokenizer=text_tokenizer,
            gloss_tokenizer=gloss_tokenizer,
            name2keypoint=dataset.name2keypoints),
        batch_size=cfg['training']['batch_size'],
        num_workers=cfg['training'].get('num_workers', 2),
        sampler=sampler,
    )
    return dataloader, sampler
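
As a quick orientation, here is a minimal sketch of how `build_dataloader` above might be driven for recognition (S2G). The config keys mirror the ones read by the function, but the concrete values, the split name, and the `gloss_tokenizer` object are assumptions for illustration only; in the repository they come from the experiment YAMLs and the tokenizer utilities.
```
# Hypothetical driver for build_dataloader; the values below are illustrative, not the repo's defaults.
cfg = {
    'task': 'S2G',
    'data': {
        'dataset_name': 'phoenix-2014t',
        'zip_file': 'data/phoenix-2014t/videos.zip',   # assumed video archive path
        'transform_cfg': {},                           # assumed augmentation/resize options
    },
    'training': {'batch_size': 2, 'shuffle': True, 'num_workers': 2},
}

# gloss_tokenizer is assumed to be built from the repository's tokenizer utilities;
# it must map a list of gloss strings to CTC labels and lengths, as collate_fn_ expects.
dataloader, sampler = build_dataloader(cfg, split='dev', gloss_tokenizer=gloss_tokenizer)
# split != 'train' with val_distributed=False -> SequentialSampler, so no
# torch.distributed initialization is required for this quick check.
for batch in dataloader:
    print(batch['name'], batch['recognition_inputs']['sgn_lengths'])
    break
```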