
Commit 8c47e7d

Minor update to env.yaml
1 parent 0d80e52

7 files changed, +83 -156 lines

README.md (+2, -2)

@@ -25,9 +25,9 @@ We introduce **M**ulti**M**odal **P**rototyping framework (**MMP**), where the r
 - 07/02/2024: The first version of MMP codebase is now live!
 
 ## Installation
-Please run the following command to create MMP conda environment.
+Once you clone the repo, please run the following command to create MMP conda environment.
 ```shell
-conda env create -f environment.yml
+conda env create -f env.yaml
 ```
 
 ## MMP Walkthrough

env.yaml (+78, new file)

@@ -0,0 +1,78 @@
+name: mmp
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - ca-certificates=2024.3.11=h06a4308_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - ncurses=6.4=h6a678d5_0
+  - openssl=3.0.14=h5eee18b_0
+  - pip=24.0=py39h06a4308_0
+  - python=3.9.19=h955ad1f_1
+  - readline=8.2=h5eee18b_0
+  - setuptools=69.5.1=py39h06a4308_0
+  - sqlite=3.45.3=h5eee18b_0
+  - tk=8.6.14=h39e8969_0
+  - wheel=0.43.0=py39h06a4308_0
+  - xz=5.4.6=h5eee18b_1
+  - zlib=1.2.13=h5eee18b_1
+  - pip:
+    - certifi==2024.6.2
+    - charset-normalizer==3.3.2
+    - ecos==2.0.14
+    - einops==0.8.0
+    - faiss-gpu==1.7.2
+    - filelock==3.15.4
+    - fsspec==2024.6.1
+    - h5py==3.11.0
+    - huggingface-hub==0.23.4
+    - idna==3.7
+    - jinja2==3.1.4
+    - joblib==1.4.2
+    - markupsafe==2.1.5
+    - mpmath==1.3.0
+    - networkx==3.2.1
+    - numexpr==2.10.1
+    - numpy==1.26.4
+    - nvidia-cublas-cu12==12.1.3.1
+    - nvidia-cuda-cupti-cu12==12.1.105
+    - nvidia-cuda-nvrtc-cu12==12.1.105
+    - nvidia-cuda-runtime-cu12==12.1.105
+    - nvidia-cudnn-cu12==8.9.2.26
+    - nvidia-cufft-cu12==11.0.2.54
+    - nvidia-curand-cu12==10.3.2.106
+    - nvidia-cusolver-cu12==11.4.5.107
+    - nvidia-cusparse-cu12==12.1.0.106
+    - nvidia-nccl-cu12==2.20.5
+    - nvidia-nvjitlink-cu12==12.5.82
+    - nvidia-nvtx-cu12==12.1.105
+    - osqp==0.6.7.post0
+    - packaging==24.1
+    - pandas==2.2.2
+    - pot==0.9.4
+    - python-dateutil==2.9.0.post0
+    - pytz==2024.1
+    - pyyaml==6.0.1
+    - qdldl==0.1.7.post4
+    - regex==2024.5.15
+    - requests==2.32.3
+    - safetensors==0.4.3
+    - scikit-learn==1.5.0
+    - scikit-survival==0.23.0
+    - scipy==1.13.1
+    - six==1.16.0
+    - sympy==1.12.1
+    - threadpoolctl==3.5.0
+    - tokenizers==0.19.1
+    - torch==2.3.1
+    - tqdm==4.66.4
+    - transformers==4.42.3
+    - triton==2.3.1
+    - typing-extensions==4.12.2
+    - tzdata==2024.1
+    - urllib3==2.2.2
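
The environment pins the GPU build of faiss and the CUDA 12.1 wheels for torch 2.3.1, so it targets a Linux machine with an NVIDIA GPU. As a quick post-install sanity check (illustrative only, not part of the repo), the key imports can be verified after `conda env create -f env.yaml` and `conda activate mmp`:

```python
# Minimal sanity-check sketch; not part of the MMP codebase.
import faiss   # faiss-gpu==1.7.2
import torch   # torch==2.3.1 (CUDA 12.1 wheels)

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("faiss GPUs visible:", faiss.get_num_gpus())
```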

environment.yml (-149)

This file was deleted.

src/scripts/prototype/clustering.sh (+1)

@@ -14,6 +14,7 @@ mode='faiss' # 'faiss' or 'kmeans'
 n_proto=16 # Number of prototypes
 n_init=3 # Number of KMeans initializations to perform
 
+
 # Validity check for feat paths
 all_feat_dirs=""
 for dataroot_path in "${dataroots[@]}"; do
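
The script exposes `mode` ('faiss' or 'kmeans'), `n_proto`, and `n_init`, which are presumably consumed by the clustering code in `src/utils/proto_utils.py`. The 'kmeans' branch is not part of this diff; the sketch below is only an assumption of how those parameters might map onto scikit-learn's KMeans (scikit-learn 1.5.0 is pinned in env.yaml), with illustrative shapes:

```python
# Hypothetical sketch of a 'kmeans' mode; not taken from the MMP repository.
import numpy as np
from sklearn.cluster import KMeans

n_proto, n_init = 16, 3                                   # values from clustering.sh
feats = np.random.rand(10000, 1024).astype(np.float32)   # stand-in patch features

km = KMeans(n_clusters=n_proto, n_init=n_init, random_state=0).fit(feats)
prototypes = km.cluster_centers_                          # shape: (n_proto, 1024)
```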

src/scripts/survival/brca_surv.sh (-1)

@@ -6,7 +6,6 @@ config=$2
 ### Dataset Information
 declare -a dataroots=(
 'path/to/tcga_brca'/
-
 )
 
 task='BRCA_survival'

src/training/trainer.py (-3)

@@ -108,16 +108,13 @@ def train(datasets, args):
         train_results = train_loop_survival(model, datasets['train'], optimizer, lr_scheduler, loss_fn,
                                             print_every=args.print_every, accum_steps=args.accum_steps)
 
-        writer = log_dict_tensorboard(writer, train_results, 'train/', epoch)
 
         ### Validation Loop (Optional)
         if 'val' in datasets.keys():
             print('#' * 11, f'VAL Epoch: {epoch}', '#' * 11)
             val_results, _ = validate_survival(model, datasets['val'], loss_fn,
                                                print_every=args.print_every, verbose=True)
 
-            writer = log_dict_tensorboard(writer, val_results, 'val/', epoch)
-
         ### Check Early Stopping (Optional)
         if early_stopper is not None:
             if args.es_metric == 'loss':
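
This hunk removes the per-epoch TensorBoard logging calls from the training and validation loops. The `log_dict_tensorboard` helper itself is defined elsewhere in the repo and is not shown here; a minimal sketch of what a helper with this call signature typically does, assuming the results dicts map metric names to scalars:

```python
# Hypothetical sketch of the logging helper; the real implementation lives
# elsewhere in src/ and may differ.
from torch.utils.tensorboard import SummaryWriter

def log_dict_tensorboard(writer: SummaryWriter, results: dict, prefix: str, step: int):
    # Write each scalar metric under the given prefix, e.g. 'train/loss'.
    for name, value in results.items():
        writer.add_scalar(prefix + name, value, step)
    return writer
```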

src/utils/proto_utils.py (+2, -1)

@@ -80,7 +80,8 @@ def cluster(data_loader, n_proto, n_iter, n_init=5, feature_dim=1024, n_proto_pa
                               verbose=True,
                               max_points_per_centroid=n_proto_patches,
                               gpu=numOfGPUs)
-        kmeans.train(patches)
+
+        kmeans.train(patches.numpy())
         weight = kmeans.centroids[np.newaxis, ...]
 
     else:
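
This change converts the patch features to a NumPy array before clustering: faiss's `Kmeans.train` expects a float32 `numpy.ndarray`, not a `torch.Tensor`. A self-contained sketch of the same pattern (shapes and hyperparameter values are illustrative, not taken from the repo):

```python
# Minimal faiss k-means sketch mirroring the fixed call; illustrative only.
import numpy as np
import torch
import faiss

feature_dim, n_proto = 1024, 16
patches = torch.randn(10000, feature_dim)          # stand-in patch features

kmeans = faiss.Kmeans(feature_dim, n_proto, niter=50, nredo=3, verbose=True)
kmeans.train(patches.numpy().astype(np.float32))   # faiss needs a float32 ndarray
weight = kmeans.centroids[np.newaxis, ...]         # shape: (1, n_proto, feature_dim)
```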
