Commit e1ab64f

committed
First commit
0 parents  commit e1ab64f

150 files changed

+28055
-0
lines changed

.gitignore

+8
@@ -0,0 +1,8 @@
**/__pycache__/
**/embeddings/
**/wandb/
.ipynb_checkpoints
src/results/
.DS_Store
.vscode

LICENSE.md

+175
Large diffs are not rendered by default.

README.md

+95
@@ -0,0 +1,95 @@
1+
# MMP
2+
3+
4+
<b>Multimodal Prototyping for cancer survival prediction</b>, ICML 2024.
5+
<br><em>Andrew H. Song, Richard J. Chen, Guillaume Jaume, Anurag Vaidya, Alexander S. Baras, Faisal Mahmood</em></br>
6+
7+
<img src="docs/mmp_logo.png" width="250px" align="right" />
8+
9+
[Paper](https://openreview.net/pdf?id=3MfvxH3Gia) | [Cite](#cite)
10+
11+
**Abstract:** Multimodal survival methods combining gigapixel histology whole-slide images (WSIs) and
12+
transcriptomic profiles are particularly promising for patient prognostication and stratification.
13+
Current approaches involve tokenizing the WSIs into smaller patches (> 10k patches) and transcriptomics into gene groups, which are then integrated using a Transformer for predicting outcomes. However, this process generates many
14+
tokens, which leads to high memory requirements for computing attention and complicates post-hoc interpretability analyses. Instead, we hypothesize that we can: (1) effectively summarize the morphological content of a WSI by
15+
condensing its constituting tokens using morphological prototypes, achieving more than 300× compression; and (2) accurately characterize cellular functions by encoding the transcriptomic profile with biological pathway prototypes, all
16+
in an unsupervised fashion.
17+
18+
We introduce **M**ulti**M**odal **P**rototyping framework (**MMP**), where the resulting multimodal tokens are then processed by a fusion network, either with a Transformer or an optimal transport cross-alignment, which now operates with a small and fixed number of tokens without approximations. Extensive evaluation shows that our framework outperforms state-of-the-art methods with much less computation while unlocking new interpretability analyses.
19+
20+
**MMP** (a.k.a. **M**ulti**M**odal **P**anther) is a multimodal extension of our companion work **PANTHER** (*CVPR 2024*, [paper](https://openaccess.thecvf.com/content/CVPR2024/html/Song_Morphological_Prototyping_for_Unsupervised_Slide_Representation_Learning_in_Computational_Pathology_CVPR_2024_paper.html), [code](https://github.com/mahmoodlab/PANTHER)), so we encourage you to check it out!
21+
22+
<img src="docs/fig1.jpg" width="1400px" align="center" />
23+
24+
## Updates
25+
- 07/02/2024: The first version of the MMP codebase is now live!
26+
27+
## Installation
28+
Run the following command to create the MMP conda environment.
29+
```shell
30+
conda env create -f environment.yml
31+
```
32+
33+
## MMP Walkthrough
34+
MMP can largely be broken down into four steps:
35+
36+
**Step 1**: Construct histology prototypes (across the specific cancer cohort) and aggregate tissue patch tokens to each prototype for each patient.\
37+
**Step 2**: Construct pathway prototypes and aggregate transcriptomic expression tokens to each prototype for each patient.\
38+
**Step 3**: Fuse the aggregated histology and pathway embeddings and perform the downstream task.\
39+
**Step 4**: Visualization.
40+
41+
### Step 1. Morphology prototype construction
42+
For **Step 1**, please refer to the instructions in the [PANTHER](https://github.com/mahmoodlab/PANTHER) repository.
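
To make the idea of "aggregating patch tokens to each prototype" concrete, here is a minimal sketch in plain Python. It is not the PANTHER procedure (see that repository for the actual method); it swaps in a simpler stand-in, hard nearest-prototype assignment followed by mean pooling. The function names and toy data are hypothetical.

```python
def nearest_prototype(token, prototypes):
    """Return the index of the prototype closest to `token` (squared Euclidean)."""
    dists = [sum((t - p) ** 2 for t, p in zip(token, proto)) for proto in prototypes]
    return dists.index(min(dists))


def aggregate_to_prototypes(tokens, prototypes):
    """Mean-pool the tokens assigned to each prototype.

    Returns one pooled vector per prototype and the fraction of tokens
    assigned to it. A prototype that receives no tokens keeps a zero vector.
    """
    dim = len(prototypes[0])
    sums = [[0.0] * dim for _ in prototypes]
    counts = [0] * len(prototypes)
    for token in tokens:
        k = nearest_prototype(token, prototypes)
        counts[k] += 1
        for d in range(dim):
            sums[k][d] += token[d]
    means = [[s / c for s in vec] if c else vec for vec, c in zip(sums, counts)]
    weights = [c / len(tokens) for c in counts]
    return means, weights


# Toy data (assumption): 6 two-dimensional "patch embeddings" and 2 prototypes.
patches = [[0.0, 0.1], [0.2, 0.0], [0.1, 0.1], [1.0, 0.9], [0.9, 1.1], [1.1, 1.0]]
prototypes = [[0.0, 0.0], [1.0, 1.0]]

slide_tokens, mixture = aggregate_to_prototypes(patches, prototypes)
print(len(patches), "patch tokens ->", len(slide_tokens), "prototype tokens")
print("assignment weights:", mixture)
```

The number of output tokens equals the number of prototypes regardless of how many patches a slide contains, which is what lets the downstream fusion network operate on a small, fixed-size input.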
43+
44+
### Step 2. Pathway prototype construction
45+
First, download the pancancer-normalized TCGA transcriptomics expression data from the Xena database.\
46+
Next, using the **hallmark oncogene sets** (located in `src/data_csvs/rna/metadata/hallmarks_signatures.csv`), keep only the genes that belong to a hallmark pathway. Note that MMP can be extended to other pathways as well.
47+
Detailed instructions can be found in the [notebook](src/preprocess_pancancer_TCGA_normalized_RNA.ipynb).
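
As a rough illustration of the filtering step, the sketch below groups one patient's expression values by pathway and drops genes that belong to no pathway. It uses in-memory dictionaries rather than the real CSV files; the layout of `hallmarks_signatures.csv` is not assumed here, and the gene and pathway names are placeholders. The notebook remains the reference for the actual preprocessing.

```python
def build_pathway_tokens(expression, pathways):
    """Group one patient's expression values by pathway.

    expression: dict mapping gene name -> normalized expression value
    pathways:   dict mapping pathway name -> list of member genes
    Genes absent from `expression` are skipped, and genes outside every
    pathway never appear in the output (the filtering step).
    """
    tokens = {}
    for name, genes in pathways.items():
        tokens[name] = [expression[g] for g in genes if g in expression]
    return tokens


# Hypothetical toy inputs; these are not the real hallmark gene sets.
pathways = {
    "PATHWAY_A": ["GENE1", "GENE2", "GENE3"],
    "PATHWAY_B": ["GENE3", "GENE4"],
}
expression = {"GENE1": 0.5, "GENE3": -1.2, "GENE4": 2.0, "GENE9": 0.7}

tokens = build_pathway_tokens(expression, pathways)
kept_genes = sorted({g for genes in pathways.values() for g in genes if g in expression})

print(tokens)
print("genes kept:", kept_genes)
```

A gene may belong to several pathways (`GENE3` above), in which case its value appears in each corresponding pathway token.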
48+
49+
### Step 3. Multimodal Fusion
50+
We can run a downstream task as follows (the data splits for the TCGA cohorts used in our study can be found in `src/splits/survival`):
51+
```shell
52+
cd src
53+
./scripts/survival/brca_surv.sh 0 mmp
54+
```
55+
where [mmp](src/scripts/survival/mmp.sh) is a bash script that contains argument examples.
56+
57+
58+
59+
MMP currently supports
60+
- **Prototype-based multimodal fusion**: Two approaches are available, `model_mm_type=coattn` (Transformer-based full attention) or `model_mm_type=coattn_mot` (OT-based cross-attention).
61+
- For the histology aggregation approach, you can specify PANTHER or OT (`model_histo_type=PANTHER,default` or `model_histo_type=OT,default`).
62+
- **SurvPath**: Adapted from [SurvPath](https://github.com/mahmoodlab/SurvPath). Specify `model_mm_type=survpath` and `model_histo_type=mil,default`.
63+
- Example script available in [survpath](src/scripts/survival/survpath.sh).
64+
- **Unimodal prototype baselines**: Use either `model_mm_type=histo` (histology prototypes only) or `model_mm_type=gene` (pathway prototypes only).
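
The `model_histo_type` values above have the form `<aggregator>,<config name>`, and the config directories in this commit are named like `OT_default`. Assuming the two are related by joining the parts with an underscore (an inference from the naming, not something verified against the training code), a hypothetical helper for locating a config file could look like this:

```python
import json
from pathlib import Path


def histo_config_path(model_histo_type, config_root="configs"):
    """Map e.g. 'OT,default' to configs/OT_default/config.json.

    The '<aggregator>_<name>' directory convention is an assumption based
    on the directory names in this commit.
    """
    aggregator, _, name = model_histo_type.partition(",")
    if not aggregator or not name:
        raise ValueError(f"expected '<aggregator>,<name>', got {model_histo_type!r}")
    return Path(config_root) / f"{aggregator}_{name}" / "config.json"


def load_histo_config(model_histo_type, config_root="configs"):
    """Read the JSON config for the given histology aggregation setting."""
    with open(histo_config_path(model_histo_type, config_root)) as f:
        return json.load(f)


print(histo_config_path("OT,default").as_posix())
```

`load_histo_config` is meant to be called from inside `src`, where the `configs` directory lives.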
65+
66+
67+
68+
### Step 4. Visualization
69+
70+
Instructions for visualizing the prototype assignment map and the histology => pathway and pathway => histology interactions are given in the [notebook](src/visualization/mmp_visualization.ipynb). Currently only `model_mm_type=coattn` is supported.
71+
72+
<img src='docs/heatmap.png' width="1400px" align="center"/>
73+
74+
## MMP future directions
75+
As emphasized in the paper, multimodal survival analysis is a challenging clinical task that has seen significant interest in the biomedical, computer vision, and machine learning communities. Though multimodal integration generally outperforms unimodal baselines, we note that the development of better unimodal baselines may (or may not) close the performance gap for certain cancer types, which is an area of further exploration.
76+
77+
## Acknowledgements
78+
If you find our work useful in your research, or if you use parts of this code, please cite our paper:
79+
80+
```bibtex
81+
@inproceedings{song2024multimodal,
82+
title={Multimodal Prototyping for cancer survival prediction},
83+
author={Song, Andrew H and Chen, Richard J and Jaume, Guillaume and Vaidya, Anurag Jayant and Baras, Alexander and Mahmood, Faisal},
84+
booktitle={Forty-first International Conference on Machine Learning},
85+
year={2024}
86+
}
87+
```
88+
89+
The code for **MMP** was adapted from and inspired by the fantastic works of [PANTHER](https://openaccess.thecvf.com/content/CVPR2024/html/Song_Morphological_Prototyping_for_Unsupervised_Slide_Representation_Learning_in_Computational_Pathology_CVPR_2024_paper.html), [SurvPath](https://github.com/mahmoodlab/SurvPath) and [CLAM](https://github.com/mahmoodlab/CLAM). Boilerplate code for setting up supervised MIL benchmarks was developed by Ming Y. Lu and Tong Ding.
90+
91+
## Issues
92+
- Please open new threads or report issues directly (for urgent blockers) to `[email protected]`.
93+
- Immediate response to minor issues may not be available.
94+
95+
<img src="docs/joint_logo.png">

docs/fig1.jpg

1.12 MB

docs/heatmap.png

1.94 MB

docs/joint_logo.png

286 KB

docs/mmp.png

157 KB

docs/mmp_logo.png

601 KB

environment.yml

+149
@@ -0,0 +1,149 @@
name: mmp
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - blas=1.0=openblas
  - bzip2=1.0.8=h5eee18b_5
  - ca-certificates=2024.2.2=hbcca054_0
  - comm=0.2.2=pyhd8ed1ab_0
  - cudatoolkit=11.4.1=h8ab8bb3_9
  - debugpy=1.6.7=py310h6a678d5_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - executing=2.0.1=pyhd8ed1ab_0
  - faiss-gpu=1.7.4=py3.10_hc0239a3_0_cuda11.4
  - ipykernel=6.29.3=pyhd33586a_0
  - ipython=8.22.2=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=7.3.4=pyhd8ed1ab_0
  - jupyter_core=5.7.2=py310hff52083_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libfaiss=1.7.4=h13c3c6d_0_cuda11.4
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgfortran-ng=11.2.0=h00389a5_1
  - libgfortran5=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libopenblas=0.3.21=h043d6bf_0
  - libsodium=1.0.18=h36c2ea0_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - numpy=1.26.4=py310heeff2f4_0
  - numpy-base=1.26.4=py310h8a23956_0
  - openssl=3.0.13=h7f8727e_0
  - packaging=24.0=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=23.3.1=py310h06a4308_0
  - platformdirs=4.2.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - psutil=5.9.0=py310h5eee18b_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - python=3.10.14=h955ad1f_0
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python_abi=3.10=2_cp310
  - pyzmq=25.1.2=py310h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py310h06a4308_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.41.2=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.12=h1ccaba5_0
  - tornado=6.1=py310h5764c6d_3
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.11.0=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.41.2=py310h06a4308_0
  - xz=5.4.6=h5eee18b_0
  - zeromq=4.3.5=h6a678d5_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - absl-py==2.1.0
      - appdirs==1.4.4
      - certifi==2024.2.2
      - charset-normalizer==3.3.2
      - click==8.1.7
      - contourpy==1.2.1
      - cycler==0.12.1
      - docker-pycreds==0.4.0
      - ecos==2.0.13
      - einops==0.7.0
      - filelock==3.13.4
      - fonttools==4.51.0
      - fsspec==2024.3.1
      - gitdb==4.0.11
      - gitpython==3.1.43
      - grpcio==1.62.2
      - h5py==3.11.0
      - huggingface-hub==0.22.2
      - idna==3.7
      - intel-openmp==2024.1.0
      - jinja2==3.1.3
      - joblib==1.4.0
      - kiwisolver==1.4.5
      - markdown==3.6
      - markupsafe==2.1.5
      - matplotlib==3.8.4
      - mkl==2024.1.0
      - mpmath==1.3.0
      - networkx==3.3
      - numexpr==2.10.0
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.19.3
      - nvidia-nvjitlink-cu12==12.4.127
      - nvidia-nvtx-cu12==12.1.105
      - nystrom-attention==0.0.12
      - osqp==0.6.5
      - pandas==2.2.2
      - pillow==10.3.0
      - protobuf==4.25.3
      - pyparsing==3.1.2
      - pytz==2024.1
      - pyyaml==6.0.1
      - qdldl==0.1.7.post2
      - regex==2024.4.16
      - requests==2.31.0
      - safetensors==0.4.3
      - scikit-learn==1.3.2
      - scikit-survival==0.22.2
      - scipy==1.11.4
      - seaborn==0.13.2
      - sentry-sdk==1.45.0
      - setproctitle==1.3.3
      - smmap==5.0.1
      - sympy==1.12
      - tbb==2021.12.0
      - tensorboard==2.16.2
      - tensorboard-data-server==0.7.2
      - threadpoolctl==3.4.0
      - tokenizers==0.19.1
      - torch==2.2.2
      - torchvision==0.17.2
      - tqdm==4.66.2
      - transformers==4.40.0
      - triton==2.2.0
      - tzdata==2024.1
      - urllib3==2.2.1
      - wandb==0.16.6
      - werkzeug==3.0.2

src/__init__.py

Whitespace-only changes.

src/configs/H2T_default/config.json

+8
@@ -0,0 +1,8 @@
{
    "in_dim": 768,
    "n_classes": 2,
    "out_size": 8,
    "load_proto": false,
    "proto_path": ".",
    "fix_proto": false
}

src/configs/OT_default/config.json

+21
@@ -0,0 +1,21 @@
{
    "in_dim": 768,
    "n_classes": 2,
    "n_filters": 1024,
    "len_motifs": 1,
    "subsamplings": 1,
    "kernel_args": 0.4,
    "weight_decay": 0.0001,
    "embed_ratio": 16,
    "ot_eps": 0.1,
    "heads": 1,
    "out_size": 4,
    "out_type": "param_cat",
    "max_iter": 100,
    "distance": "euclidean",
    "fit_bias": false,
    "alternating": false,
    "load_proto": false,
    "proto_path": ".",
    "fix_proto": true
}
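
For readers unfamiliar with the `ot_eps` and `max_iter` fields, the sketch below shows a generic Sinkhorn iteration for entropy-regularized optimal transport. It assumes `ot_eps` is the entropic regularization strength and `max_iter` caps the number of iterations; both readings are assumptions, the uniform marginals are an illustrative choice, and this is not the repository's implementation.

```python
import math


def sinkhorn(cost, eps=0.1, max_iter=100):
    """Entropic OT between uniform marginals; returns the transport plan.

    cost[i][j] is the cost of moving mass from source i to target j.
    """
    n, m = len(cost), len(cost[0])
    a = [1.0 / n] * n  # uniform source marginal (assumption)
    b = [1.0 / m] * m  # uniform target marginal (assumption)
    # Gibbs kernel: smaller eps makes the plan closer to a hard assignment.
    K = [[math.exp(-c / eps) for c in row] for row in cost]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(max_iter):
        # Alternately rescale rows and columns to match the marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


# Toy cost matrix (assumption): 4 tokens versus 2 prototypes.
cost = [[0.0, 1.0], [0.1, 0.9], [1.0, 0.0], [0.9, 0.1]]
plan = sinkhorn(cost, eps=0.1, max_iter=100)

for row in plan:
    print([round(p, 4) for p in row])
```

Each row of the plan gives how one token's mass is split across prototypes, so tokens that are cheap to move to a prototype send most of their mass there.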
+15
@@ -0,0 +1,15 @@
{
    "in_dim": 768,
    "n_classes": 2,
    "heads": 1,
    "em_iter": 1,
    "tau": 0.001,
    "ot_eps": 0.1,
    "n_fc_layers": 0,
    "dropout": 0.25,
    "out_type": "param_cat",
    "out_size": 8,
    "load_proto": false,
    "proto_path": ".",
    "fix_proto": false
}
@@ -0,0 +1,8 @@
{
    "in_dim": 768,
    "n_classes": 2,
    "out_size": 8,
    "load_proto": true,
    "proto_path": ".",
    "fix_proto": false
}

src/data_csvs/rna/hallmarks/BLCA/rna_clean.csv

+361
Large diffs are not rendered by default.

src/data_csvs/rna/hallmarks/BRCA/rna_clean.csv

+940
Large diffs are not rendered by default.

src/data_csvs/rna/hallmarks/COADREAD/rna_clean.csv

+321
Large diffs are not rendered by default.

src/data_csvs/rna/hallmarks/KIRC/rna_clean.csv

+607
Large diffs are not rendered by default.

src/data_csvs/rna/hallmarks/LUAD/rna_clean.csv

+577
Large diffs are not rendered by default.

src/data_csvs/rna/hallmarks/STAD/rna_clean.csv

+337
Large diffs are not rendered by default.
