
Commit: TwoStream
2000ZRL committed Mar 20, 2023
1 parent 99fd3a4 commit 7b8822e
Showing 66 changed files with 9,829 additions and 13 deletions.
142 changes: 142 additions & 0 deletions TwoStreamNetwork/.gitignore
@@ -0,0 +1,142 @@
checkpoints/
experiments_/
wandb
*.ipynb
*.sh
data/csl-daily/*.zip
data/csl-daily/*keypoints*
data/phoenix-2014*/*.zip
data/phoenix-2014*/*keypoints*
experiments/outputs
experiments/configs
experiments/scripts
pretrained_models/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
95 changes: 82 additions & 13 deletions TwoStreamNetwork/README.md
@@ -1,23 +1,92 @@
# TwoStreamSLT
A TwoStream network for sign language recognition and translation, including the official implementations of
* [A Simple Multi-modality Transfer Learning Baseline for Sign Language Translation, CVPR2022](https://arxiv.org/abs/2203.04287)
* [Two-Stream Network for Sign Language Recognition and Translation, NeurIPS2022](https://arxiv.org/abs/2211.01367).

## Introduction
Sign Language Translation (SLT) and Sign Language Recognition (SLR) suffer from data scarcity. To mitigate this problem, we first propose [a simple multi-modality transfer learning baseline for SLT](https://arxiv.org/abs/2203.04287), which leverages extra supervision from large-scale general-domain datasets: modules are progressively pretrained, first on general-domain data and then on within-domain data, before a final multi-modal joint training stage. This simple yet effective baseline achieves strong translation performance and improves significantly over previous work.

<img src="images/baseline.png" width="800">

We further propose a [TwoStream network for SLR and SLT](https://arxiv.org/abs/2211.01367), which incorporates domain knowledge of human keypoints into the visual encoder. The TwoStream network achieves state-of-the-art performance across SLR and SLT benchmarks (`18.8 WER on Phoenix-2014 and 19.3 WER on Phoenix-2014T, 29.0 BLEU4 on Phoenix-2014T, and 25.8 BLEU4 on CSL-Daily`).

<img src="images/TwoStream_illustration.png" width="750">

## Performance

**SingleStream-SLT** (the simple multi-modality transfer learning baseline for SLT)

| Dataset | ROUGE | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | Model | Training |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Phoenix-2014T | 53.08 | 54.48 | 41.93 | 33.97 | 28.57 | [ckpt]() | [config](experiments/configs/SingleStream/phoenix-2014t_s2t.yaml) |
| CSL-Daily | 53.35 | 53.53 | 40.68 | 31.04 | 24.09 |[ckpt]() | [config](experiments/configs/SingleStream/csl-daily_s2t.yaml) |

**TwoStream-SLR**

| Dataset | WER | Model | Training |
| :---: | :---: | :---: | :---: |
| Phoenix-2014 | 18.8 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014_s2g.yaml) |
| Phoenix-2014T | 19.3 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014t_s2g.yaml) |
| CSL-Daily | 25.3 | [ckpt]() | [config](experiments/configs/TwoStream/csl-daily_s2g.yaml) |

**TwoStream-SLT**

| Dataset | ROUGE | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | Model | Training |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Phoenix-2014T | 53.48 | 54.90 | 42.43 | 34.46 | 28.95 | [ckpt]() | [config](experiments/configs/TwoStream/phoenix-2014t_s2t_ensemble.yaml) |
| CSL-Daily | 55.72 | 55.44 | 42.59 | 32.87 | 25.79 | [ckpt]() | [config](experiments/configs/TwoStream/csl-daily_s2t_ensemble.yaml) |

## Usage
### Prerequisites
Create an environment and install dependencies.
```
conda env create -f environment.yml
conda activate slt
```
### Download
You can run [download.sh](download.sh), which automatically downloads the datasets (except CSL-Daily, which requires submitting an agreement before download), pretrained models, and keypoints, and places them in the corresponding locations. Alternatively, you can download these files separately as follows.

**Datasets**

Download the datasets from their websites and place them under the corresponding directories in data/:
* [Phoenix-2014](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX/)
* [Phoenix-2014T](https://www-i6.informatik.rwth-aachen.de/~koller/RWTH-PHOENIX-2014-T/)
* [CSL-Daily](http://home.ustc.edu.cn/~zhouh156/dataset/csl-daily/)

Then run [preprocess/preprocess_video.sh](preprocess/preprocess_video.sh) to extract the downloaded videos.

**Pretrained Models**

We provide pretrained models [here](). Download this directory and place it as *pretrained_models*. Specifically, the required pretrained models include:
* *s3ds_actioncls_ckpt*: S3D backbone pretrained on Kinetics-400. (From [https://github.com/kylemin/S3D](https://github.com/kylemin/S3D). Thanks for their implementation!)
* *s3ds_glosscls_ckpt*: S3D backbone pretrained on Kinetics-400 and WLASL.
* *mbart_de* / *mbart_zh*: pretrained language models used to initialize the translation network for German and Chinese, with weights from [mbart-cc-25](https://huggingface.co/facebook/mbart-large-cc25). We prune mBART's original word embedding, keeping only German or Chinese tokens, to avoid running out of GPU memory. We also compute gloss embeddings by averaging the mBART-pretrained embeddings of all sub-tokens of each gloss. (See [utils/prune_embedding.ipynb](utils/prune_embedding.ipynb) and the sketch below.)
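
As a rough sketch of that averaging step (this is not the repo's notebook; the helper name and example glosses are purely illustrative):

```
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer

# Load mBART-cc25 and take its (un-pruned) input embedding matrix.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
embedding_matrix = model.get_input_embeddings().weight.detach()  # (vocab_size, hidden_dim)

def gloss_embedding(gloss):
    # Average the pretrained embeddings of the gloss's sub-tokens.
    sub_token_ids = tokenizer(gloss, add_special_tokens=False)["input_ids"]
    return embedding_matrix[sub_token_ids].mean(dim=0)

# Illustrative glosses only; the real gloss vocabulary comes from the datasets.
gloss2embedding = {g: gloss_embedding(g) for g in ["REGEN", "MORGEN"]}
```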

**Keypoints** (only needed for TwoStream)

We provide [human keypoints]() for the three datasets, pre-extracted by HRNet. Please download them and place them under *data/phoenix-2014t*, *data/phoenix-2014*, or *data/csl-daily*.


### Training and Evaluation

* For **SingleStream-SLT Baseline**, please see [SingleStream-SLT.md](docs/SingleStream-SLT.md).
* For **TwoStream-SLR**, please see [TwoStream-SLR.md](docs/TwoStream-SLR.md).
* For **TwoStream-SLT**, please see [TwoStream-SLT.md](docs/TwoStream-SLT.md) (builds on TwoStream-SLR).

## Citations
```
@inproceedings{chen2022twostream,
  title={Two-Stream Network for Sign Language Recognition and Translation},
  author={Yutong Chen and Ronglai Zuo and Fangyun Wei and Yu Wu and Shujie Liu and Brian Mak},
  booktitle={Advances in Neural Information Processing Systems},
  editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year={2022},
  url={https://openreview.net/forum?id=hSxK-4KGLbI}
}
@InProceedings{Chen_2022_CVPR,
  author    = {Chen, Yutong and Wei, Fangyun and Sun, Xiao and Wu, Zhirong and Lin, Stephen},
  title     = {A Simple Multi-Modality Transfer Learning Baseline for Sign Language Translation},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022},
  pages     = {5120-5130}
}
```

Binary file added TwoStreamNetwork/data/csl-daily/csl-daily.dev
Binary file not shown.
Binary file added TwoStreamNetwork/data/csl-daily/csl-daily.test
Binary file not shown.
Binary file added TwoStreamNetwork/data/csl-daily/csl-daily.train
Binary file not shown.
Binary file added TwoStreamNetwork/data/csl-daily/gloss2ids.pkl
Binary file not shown.
Binary file added TwoStreamNetwork/data/phoenix-2014/gloss2ids.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added TwoStreamNetwork/data/phoenix-2014t/gloss2ids.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
87 changes: 87 additions & 0 deletions TwoStreamNetwork/dataset/Dataloader.py
@@ -0,0 +1,87 @@
from dataset.VideoLoader import load_batch_video
from dataset.FeatureLoader import load_batch_feature
from dataset.Dataset import build_dataset
import torch


def collate_fn_(inputs, data_cfg, task, is_train,
                text_tokenizer=None, gloss_tokenizer=None, name2keypoint=None):
    # Collate a list of dataset samples into batched inputs for the given task
    # (S2G recognition, S2T/G2T translation, or S2T_Ensemble decoding).
    outputs = {
        'name': [i['name'] for i in inputs],
        'gloss': [i.get('gloss', '') for i in inputs],
        'text': [i.get('text', '') for i in inputs],
        'num_frames': [i['num_frames'] for i in inputs]}
    if task == 'S2G':
        outputs['recognition_inputs'] = gloss_tokenizer(outputs['gloss'])

        # Load and pad raw videos (and keypoints, if provided) for recognition.
        sgn_videos, sgn_keypoints, sgn_lengths = load_batch_video(
            zip_file=data_cfg['zip_file'],
            names=outputs['name'],
            num_frames=outputs['num_frames'],
            transform_cfg=data_cfg['transform_cfg'],
            dataset_name=data_cfg['dataset_name'],
            pad_length=data_cfg.get('pad_length', 'pad_to_max'),
            pad=data_cfg.get('pad', 'replicate'),
            is_train=is_train,
            name2keypoint=name2keypoint,
        )
        outputs['recognition_inputs']['sgn_videos'] = sgn_videos
        outputs['recognition_inputs']['sgn_keypoints'] = sgn_keypoints
        outputs['recognition_inputs']['sgn_lengths'] = sgn_lengths

    if task in ['S2T', 'G2T', 'S2T_Ensemble']:
        tokenized_text = text_tokenizer(input_str=outputs['text'])
        outputs['translation_inputs'] = {**tokenized_text}
        if task == 'S2T':
            outputs['recognition_inputs'] = gloss_tokenizer(outputs['gloss'])
            outputs['translation_inputs']['gloss_ids'] = outputs['recognition_inputs']['gloss_labels']
            outputs['translation_inputs']['gloss_lengths'] = outputs['recognition_inputs']['gls_lengths']
            # Batch any pre-extracted features present in the samples
            # (the small epsilon keeps feature values from being exactly zero).
            for feature_name in ['sgn_features', 'head_rgb_input', 'head_keypoint_input']:
                if feature_name in inputs[0]:
                    outputs['recognition_inputs'][feature_name], sgn_mask, sgn_lengths = \
                        load_batch_feature(features=[i[feature_name] + 1.0e-8 for i in inputs])
            outputs['recognition_inputs']['sgn_mask'] = sgn_mask
            outputs['recognition_inputs']['sgn_lengths'] = sgn_lengths
        elif task == 'G2T':
            tokenized_gloss = gloss_tokenizer(batch_gls_seq=outputs['gloss'])
            outputs['translation_inputs']['input_ids'] = tokenized_gloss['input_ids']
            outputs['translation_inputs']['attention_mask'] = tokenized_gloss['attention_mask']
        elif task == 'S2T_Ensemble':
            # Each sample carries one embedding sequence per ensemble member.
            outputs['translation_inputs']['inputs_embeds_list'] = []
            outputs['translation_inputs']['attention_mask_list'] = []
            for ii in range(len(inputs[0]['inputs_embeds_list'])):
                inputs_embeds, mask_, _ = load_batch_feature(
                    features=[i['inputs_embeds_list'][ii] for i in inputs])
                outputs['translation_inputs']['inputs_embeds_list'].append(inputs_embeds)
                outputs['translation_inputs']['attention_mask_list'].append(mask_)
    return outputs

def build_dataloader(cfg, split,
                     text_tokenizer=None, gloss_tokenizer=None,
                     mode='auto', val_distributed=False):
    # Build a DataLoader for the given split: a DistributedSampler for training,
    # and a sequential (or optionally distributed) sampler for evaluation.
    dataset = build_dataset(cfg['data'], split)
    mode = split if mode == 'auto' else mode
    if mode == 'train':
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            shuffle=cfg['training']['shuffle'] and split == 'train'
        )
    else:
        if val_distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=lambda x: collate_fn_(
            inputs=x,
            task=cfg['task'],
            data_cfg=cfg['data'],
            is_train=(mode == 'train'),
            text_tokenizer=text_tokenizer,
            gloss_tokenizer=gloss_tokenizer,
            name2keypoint=dataset.name2keypoints),
        batch_size=cfg['training']['batch_size'],
        num_workers=cfg['training'].get('num_workers', 2),
        sampler=sampler,
    )
    return dataloader, sampler
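
A hypothetical driver sketch (not part of this commit) showing how `build_dataloader` above could be called; the config path is one of the configs listed in the README, while the tokenizer placeholder and epoch count are assumptions:

```
import yaml
from dataset.Dataloader import build_dataloader

# Assumes torch.distributed has been initialized, since the training branch
# of build_dataloader uses DistributedSampler.
with open('experiments/configs/TwoStream/phoenix-2014t_s2g.yaml') as f:
    cfg = yaml.safe_load(f)  # provides the 'task', 'data', and 'training' sections read above

gloss_tokenizer = ...  # placeholder: repo-specific gloss tokenizer used by S2G collation

train_loader, train_sampler = build_dataloader(cfg, split='train', gloss_tokenizer=gloss_tokenizer)
dev_loader, _ = build_dataloader(cfg, split='dev', gloss_tokenizer=gloss_tokenizer)

for epoch in range(40):              # illustrative epoch count
    train_sampler.set_epoch(epoch)   # reshuffle shards each epoch under DDP
    for batch in train_loader:
        recognition_inputs = batch['recognition_inputs']  # videos, keypoints, lengths, gloss labels
        # ... model forward / backward ...
```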
