
Commit

Merge pull request snap-stanford#8 from snap-stanford/new_species_token_fix

Simplify Embedding of novel species
yhr91 authored Dec 19, 2023
2 parents cce1ba1 + e4dede8 commit 3ded79b
Showing 8 changed files with 57 additions and 25 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -44,6 +44,8 @@ outputs from `eval_single_anndata.py` are stored in the `dir` directory.

You can download processed datasets used in the paper [here](https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7?usp=drive_link)

**Note:** These datasets were embedded using the 33 layer model. Embeddings for the 33 layer model are not compatible with embeddings from the 4 layer model.

## Citing

If you find our paper and code useful, please consider citing the [preprint](https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1):
50 changes: 32 additions & 18 deletions data_proc/Create New Species Files.ipynb
@@ -540,26 +540,19 @@
}
],
"source": [
"torch.manual_seed(8) # DEFAULT SEED FOR GENERATING \n",
"MASK_TENSOR = torch.zeros((1, token_dim)) # this is the padding token\n",
"CHROM_TENSOR_LEFT = torch.normal(mean=0, std=1, size=(1, token_dim))\n",
"CHROM_TENSOR_RIGHT = torch.normal(mean=0, std=1, size=(1, token_dim))\n",
"CLS_TENSOR = torch.normal(mean=0, std=1, size=(1, token_dim))\n",
"\n",
"\n",
"\n",
"species_to_offsets = {}\n",
"\n",
"all_pe = [MASK_TENSOR, CHROM_TENSOR_LEFT, CHROM_TENSOR_RIGHT, CLS_TENSOR]\n",
"all_pe = torch.load(\"../model_files/all_tokens.torch\")[0:4] # read in existing token file to make sure \n",
"# that special vocab tokens are the same for different seeds\n",
"\n",
"offset = len(all_pe) # special tokens at the top!\n",
"\n",
"PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)\n",
"\n",
"pe_stacked = torch.stack(list(PE.values()))\n",
"all_pe.append(pe_stacked)\n",
"all_pe = torch.vstack((all_pe, pe_stacked))\n",
"species_to_offsets[species] = offset\n",
"\n",
"all_pe = torch.vstack(all_pe)\n",
"print(\"CHROM_TOKEN_OFFSET:\", all_pe.shape[0])\n",
"torch.manual_seed(TAXONOMY_ID)\n",
"CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, 5120)) \n",
@@ -602,6 +595,29 @@
{
"cell_type": "code",
"execution_count": 16,
"id": "21f937ea",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([13341, 5120])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all_pe.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5faadace",
"metadata": {},
"outputs": [
@@ -619,7 +635,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "6ceac20b",
"metadata": {},
"outputs": [
@@ -629,7 +645,7 @@
"'../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt'"
]
},
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -653,12 +669,10 @@
"source": [
"**Note: when you evaluate a new species, you need to change some arguments and modify some files:**\n",
"\n",
"You will need to modify the codebase to include the new protein embeddings file you downloaded.\n",
"\n",
"In the file `data_proc/gene_embeddings.py`, please add a line corresponding to the new species in the dictionary for ESM2 embeddings created on line 13.\n",
"You will need to modify the csv in `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.\n",
"\n",
"The format should match the existing species already listed there. For chicken, it would be:\n",
"`\"chicken\": EMBEDDING_DIR / 'Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt',`\n",
"In the file add a row for the new species with the format:\n",
"`species name,full path to protein embedding file`\n",
"\n",
"Please also add this line to the dictionary created on line 247 in the file `data_proc/data_utils.py`.\n",
"\n",
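The markdown cell above describes the new registration flow. A minimal sketch (not part of the commit) of appending a row to the csv and reading it back the same way the codebase now does; the chicken path is an illustrative placeholder for the full path to your downloaded embedding file:

```python
import pandas as pd

csv_path = "model_files/new_species_protein_embeddings.csv"

# Row format: species name, full path to the protein-embedding .pt file.
new_row = pd.DataFrame([{
    "species": "chicken",
    "path": "/full/path/to/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt",
}])
pd.concat([pd.read_csv(csv_path), new_row]) \
  .drop_duplicates(subset="species", keep="last") \
  .to_csv(csv_path, index=False)

# Sanity check: this mirrors how data_utils.py and gene_embeddings.py consume the file.
extra_species = pd.read_csv(csv_path).set_index("species").to_dict()["path"]
print(extra_species["chicken"])
```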
Binary file modified data_proc/__pycache__/data_utils.cpython-38.pyc
Binary file not shown.
Binary file modified data_proc/__pycache__/gene_embeddings.cpython-38.pyc
Binary file not shown.
9 changes: 7 additions & 2 deletions data_proc/data_utils.py
@@ -175,8 +175,9 @@ def process_raw_anndata(row, h5_folder_path, npz_folder_path, scp, skip,
path = row.path
if not os.path.isfile(root + "/" + path):
print( "**********************************")
print(f"***********{path} ERROR***********")
print(f"***********{root + '/' + path} File Missing****")
print( "**********************************")
print(path, root)
return None

name = path.replace(".h5ad", "")
@@ -254,7 +255,11 @@ def get_species_to_pe(EMBEDDING_DIR):
"macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt',
"macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt',
}

extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"]
embeddings_paths.update(extra_species) # adds new species



species_to_pe = {
species:torch.load(pe_dir) for species, pe_dir in embeddings_paths.items()
}
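With the csv wired in, a newly registered species flows through `get_species_to_pe` like the built-in ones. A hedged usage sketch, assuming the repository root is the working directory (the csv path inside the function is relative) and that each `.pt` file maps gene symbols to ESM2 embedding tensors:

```python
from pathlib import Path

from data_proc.data_utils import get_species_to_pe

# Run from the repository root so "./model_files/new_species_protein_embeddings.csv" resolves.
species_to_pe = get_species_to_pe(Path("model_files/protein_embeddings"))
print(sorted(species_to_pe))                 # built-in species plus csv entries, e.g. 'chicken'

chicken_pe = species_to_pe["chicken"]        # assumed: dict of gene symbol -> embedding tensor
gene, emb = next(iter(chicken_pe.items()))
print(gene, emb.shape)                       # expected to be a 5120-dimensional ESM2 embedding
```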
4 changes: 3 additions & 1 deletion data_proc/gene_embeddings.py
@@ -6,6 +6,7 @@

from scanpy import AnnData
import numpy as np
import pandas as pd


EMBEDDING_DIR = Path('model_files/protein_embeddings')
@@ -14,7 +15,6 @@
'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt',
'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt',
'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt',

'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt',
"mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt",
"pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt',
@@ -23,6 +23,8 @@
}
}

extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"]
MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH["ESM2"].update(extra_species) # adds new species


def load_gene_embeddings_adata(adata: AnnData, species: list, embedding_model: str) -> Tuple[AnnData, Dict[str, torch.FloatTensor]]:
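For completeness, a hedged call sketch of `load_gene_embeddings_adata` based only on the signature shown above; the dataset path is hypothetical and the returned dict is assumed to be keyed by species name:

```python
import scanpy as sc

from data_proc.gene_embeddings import load_gene_embeddings_adata

adata = sc.read_h5ad("path/to/chicken_cells.h5ad")     # hypothetical AnnData input
adata_filtered, embeddings_by_species = load_gene_embeddings_adata(
    adata, species=["chicken"], embedding_model="ESM2"
)
print(adata_filtered.shape)
print(embeddings_by_species["chicken"].shape)          # assumed: genes x 5120 FloatTensor
```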
15 changes: 11 additions & 4 deletions evaluate.py
@@ -196,12 +196,19 @@ def run_eval(adata, name, pe_idx_path, chroms_path, starts_path, shapes_dict,
output_dim=args.output_dim)
if args.model_loc is None:
raise ValueError("Must provide a model location")

all_pe = get_ESM2_embeddings(args)
all_pe.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
# initialize as empty
empty_pe = torch.zeros(145469, 5120)
empty_pe.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(empty_pe)
model.load_state_dict(torch.load(args.model_loc, map_location="cpu"),
strict=True)
# Load in the real token embeddings
all_pe = get_ESM2_embeddings(args)
# This will make sure that you don't overwrite the tokens in case you're embedding species from the training data
# We avoid doing that just in case the random seeds are different across different versions.
if all_pe.shape[0] != 145469:
all_pe.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
print(f"Loaded model:\n{args.model_loc}")
model = model.eval()
model = accelerator.prepare(model)
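The `evaluate.py` change follows a load-then-swap pattern: restore the checkpoint against a zero placeholder token table, then replace the table only when the newly assembled one has a different vocabulary size. A stripped-down sketch of that pattern, assuming `model`, `args`, and `get_ESM2_embeddings` as in `run_eval`:

```python
import torch
import torch.nn as nn

TRAIN_VOCAB_SIZE, TOKEN_DIM = 145469, 5120   # token table shape of the released checkpoint

# 1) Zero placeholder so load_state_dict(strict=True) can restore the trained token table.
placeholder = torch.zeros(TRAIN_VOCAB_SIZE, TOKEN_DIM)
placeholder.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(placeholder)
model.load_state_dict(torch.load(args.model_loc, map_location="cpu"), strict=True)

# 2) Swap in the assembled table only for a new species (vocabulary size differs);
#    training species keep the checkpoint's own tokens, sidestepping seed mismatches.
all_pe = get_ESM2_embeddings(args)
if all_pe.shape[0] != TRAIN_VOCAB_SIZE:
    all_pe.requires_grad = False
    model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
```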
2 changes: 2 additions & 0 deletions model_files/new_species_protein_embeddings.csv
@@ -0,0 +1,2 @@
species,path
chicken,/dfs/project/cross-species/yanay/code/uce_code/UCE_public/model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt
