Skip to content

Commit bc6cbe4

Browse files
authored
Merge ee95724 into 61c0a95
2 parents 61c0a95 + ee95724 commit bc6cbe4

File tree

8 files changed

+467
-2
lines changed

8 files changed

+467
-2
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ singularity/*
1818
*.arrow
1919
*zip
2020
*.npy
21-
21+
*.json
2222
*.pickle
2323
*.pkl
2424
*.bin

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ A lot of our models have been published by talented authors developing these excit
147147
- [scanpy](https://github.com/scverse/scanpy)
148148
- [transformers](https://github.com/huggingface/transformers)
149149
- [scikit-learn](https://github.com/scikit-learn/scikit-learn)
150+
- [GenePT](https://github.com/yiqunchen/GenePT)
151+
- [Caduceus](https://github.com/kuleshov-group/caduceus)
150152

151153
### Licenses
152154

examples/notebooks/Genegpt-sample-run.ipynb

+230
Large diffs are not rendered by default.

helical/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def filter(self, record):
2929
from .models.uce.fine_tuning_model import UCEFineTuningModel
3030
from .models.geneformer.model import Geneformer,GeneformerConfig
3131
from .models.geneformer.fine_tuning_model import GeneformerFineTuningModel
32+
from .models.genept.model import GenePT,GenePTConfig
3233
from .models.scgpt.model import scGPT, scGPTConfig
3334
from .models.scgpt.fine_tuning_model import scGPTFineTuningModel
3435
from .models.hyena_dna.model import HyenaDNA, HyenaDNAConfig

helical/models/genept/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import GenePT,GenePTConfig
helical/models/genept/genept_config.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Optional
2+
from pathlib import Path
3+
from helical.constants.paths import CACHE_DIR_HELICAL
4+
from typing import Literal
5+
6+
class GenePTConfig():
    """Configuration class to use the GenePT Model.

    Parameters
    ----------
    model_name : Literal["gpt3.5"], optional, default="gpt3.5"
        The name of the model for the embeddings. Must be a key of ``model_map``.
    batch_size : int, optional, default = 24
        The batch size
    emb_layer : int, optional, default = -1
        The embedding layer
    emb_mode : Literal["cls", "cell", "gene"], optional, default="cell"
        The embedding mode
    device : Literal["cpu", "cuda"], optional, default="cpu"
        The device to use. Either use "cuda" or "cpu".
    accelerator : bool, optional, default=False
        The accelerator configuration. By default same device as model.
    nproc: int, optional, default=1
        Number of processes to use for data processing.
    custom_attr_name_dict : dict, optional, default=None
        A dictionary that contains the names of the custom attributes to be added to the dataset.
        The keys of the dictionary are the names of the custom attributes, and the values are the names of the columns in adata.obs.
        For example, if you want to add a custom attribute called "cell_type" to the dataset, you would pass custom_attr_name_dict = {"cell_type": "cell_type"}.
        If you do not want to add any custom attributes, you can leave this parameter as None.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the keys of ``model_map``.

    Attributes
    ----------
    model_map : dict
        The model specific parameters ('input_size', 'special_token', 'embsize') per model name.
    config : dict
        All settings passed to the constructor, the parameters of the selected model,
        the path of the embeddings file below ``CACHE_DIR_HELICAL`` ("embeddings_path")
        and the files to download ("list_of_files_to_download").
    """
    def __init__(
            self,
            model_name: Literal["gpt3.5"] = "gpt3.5",
            batch_size: int = 24,
            emb_layer: int = -1,
            emb_mode: Literal["cls", "cell", "gene"] = "cell",
            device: Literal["cpu", "cuda"] = "cpu",
            accelerator: Optional[bool] = False,
            nproc: int = 1,
            custom_attr_name_dict: Optional[dict] = None
            ):

        # model specific parameters
        self.model_map = {
            "gpt3.5": {
                'input_size': 4096,
                'special_token': True,
                'embsize': 512,
            }
        }
        if model_name not in self.model_map:
            # list(...) so the message shows ['gpt3.5'] rather than dict_keys(['gpt3.5'])
            raise ValueError(f"Model name {model_name} not found in available models: {list(self.model_map)}")
        model_params = self.model_map[model_name]

        # Single source of truth for the embeddings file: the same relative name
        # is used for the download and for the location inside the cache directory.
        embeddings_file = "genept/genept_embeddings/genept_embeddings.json"
        list_of_files_to_download = [
            embeddings_file,
        ]
        embeddings_path = Path(CACHE_DIR_HELICAL, embeddings_file)

        self.config = {
            "embeddings_path": embeddings_path,
            "model_name": model_name,
            "batch_size": batch_size,
            "emb_layer": emb_layer,
            "emb_mode": emb_mode,
            "device": device,
            "accelerator": accelerator,
            "input_size": model_params["input_size"],
            "special_token": model_params["special_token"],
            "embsize": model_params["embsize"],
            "nproc": nproc,
            "custom_attr_name_dict": custom_attr_name_dict,
            "list_of_files_to_download": list_of_files_to_download
        }
80+
81+
82+

helical/models/genept/model.py

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from helical.models.base_models import HelicalRNAModel
2+
import logging
3+
import numpy as np
4+
from anndata import AnnData
5+
from helical.utils.downloader import Downloader
6+
from helical.models.genept.genept_config import GenePTConfig
7+
from helical.utils.mapping import map_ensembl_ids_to_gene_symbols
8+
import logging
9+
import scanpy as sc
10+
import torch
11+
import json
12+
import torch
13+
14+
# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(__name__)
15+
class GenePT(HelicalRNAModel):
    """GenePT Model.

    Builds one embedding per observation of an AnnData object from
    pre-computed gene embeddings: the expression matrix is multiplied with the
    matrix of gene embeddings loaded from the JSON file referenced by the
    configuration ("embeddings_path"), and every resulting row is L2-normalised.

    Example
    -------
    ```python
    genept = GenePT()
    dataset = genept.process_data(adata)
    embeddings = genept.get_embeddings(dataset)
    ```

    Parameters
    ----------
    configurer : GenePTConfig, optional, default = default_configurer
        The model configuration
    """
    default_configurer = GenePTConfig()

    def __init__(self, configurer: GenePTConfig = default_configurer):
        super().__init__()
        self.configurer = configurer
        self.config = configurer.config

        downloader = Downloader()
        for file in self.config["list_of_files_to_download"]:
            downloader.download_via_name(file)

        # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
        with open(self.config['embeddings_path'], "r", encoding="utf-8") as f:
            self.embeddings = json.load(f)

        LOGGER.info("GenePT initialized successfully.")

    def process_data(self,
                     adata: AnnData,
                     gene_names: str = "index",
                     use_raw_counts: bool = True,
                     n_top_genes: int = 1000,
                     ) -> AnnData:
        """
        Processes the data for the GenePT model.

        The data is validated, the highly variable genes are selected, the counts are
        normalised to a total of 1e4 per observation and log1p-transformed, and the
        result is restricted to the highly variable genes.

        Parameters
        ----------
        adata : AnnData
            The AnnData object containing the data to be processed.
            NOTE(review): the scanpy preprocessing calls below are made without
            requesting a copy, so the passed object is presumably modified in place - confirm.
        gene_names : str, optional, default="index"
            The column in `adata.var` that contains the gene names. It is forwarded to
            `ensure_rna_data_validity`.
            - If set to "ensembl_id", the column is checked and then passed to
              `map_ensembl_ids_to_gene_symbols`.
            - For any other value no mapping takes place.
        use_raw_counts : bool, optional, default=True
            Determines whether raw counts should be used. Forwarded to `ensure_rna_data_validity`.
        n_top_genes : int, optional, default=1000
            The number of highly variable genes to keep.

        Raises
        ------
        ValueError
            If `gene_names` is "ensembl_id" and all entries of that column start with "ENS"
            or any entry starts with "None".

        Returns
        -------
        AnnData
            The normalised, log-transformed data restricted to the highly variable genes.
        """
        LOGGER.info("Processing data for GenePT.")
        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

        # map ensembl ids to gene symbols if provided
        if gene_names == "ensembl_id":
            # NOTE(review): this check raises exactly when the column consists of Ensembl IDs
            # (all entries start with "ENS"), so the mapping below only runs for columns that
            # are not purely Ensembl IDs. Condition and message look carried over from a
            # symbol -> Ensembl ID mapping; behaviour is kept unchanged here - confirm intent.
            if (adata.var[gene_names].str.startswith("ENS").all()) or (adata.var[gene_names].str.startswith("None").any()):
                message = "It seems an anndata with 'ensemble ids' and/or 'None' was passed. " \
                          "Please set gene_names='ensembl_id' and remove 'None's to skip mapping."
                LOGGER.info(message)
                raise ValueError(message)
            adata = map_ensembl_ids_to_gene_symbols(adata, gene_names)

        LOGGER.info(f"Filtering the top {n_top_genes} highly variable genes.")
        # The genes are flagged before normalisation; they are only filtered afterwards.
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor='seurat_v3')
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        genes_names = adata.var_names[adata.var['highly_variable']].tolist()
        adata = adata[:, genes_names]

        LOGGER.info("Successfully processed the data for GenePT.")
        return adata

    def get_text_embeddings(self, dataset: AnnData) -> torch.Tensor:
        """Gets the un-normalised embeddings from the GenePT gene embeddings.

        Every name in `dataset.var_names` is looked up (upper-cased) in the loaded
        embeddings. Genes without an entry are skipped and only counted in the log.
        The expression values of the remaining genes are multiplied with the matrix
        of their embeddings.

        Parameters
        ----------
        dataset : AnnData
            The processed data.

        Raises
        ------
        ValueError
            If none of the genes in `dataset.var_names` has an embedding.

        Returns
        -------
        torch.Tensor
            One row per observation of `dataset`, one column per embedding dimension.
        """
        weights = []
        gene_list = []
        count_missed = 0
        for gene_name in dataset.var_names:
            entry = self.embeddings.get(gene_name.upper(), None)
            if entry is not None:
                weights.append(entry['embeddings'])
                gene_list.append(gene_name)
            else:
                count_missed += 1
        LOGGER.info("Couldn't find {} genes in embeddings".format(count_missed))

        if not gene_list:
            # Without a single match the product below would degenerate to a 1-D tensor.
            raise ValueError("None of the genes in the dataset were found in the GenePT embeddings.")

        expression = dataset[:, gene_list].X
        # X may be sparse or already dense; only sparse matrices provide toarray().
        if hasattr(expression, "toarray"):
            expression = expression.toarray()

        weights = torch.Tensor(weights)
        embeddings = torch.matmul(torch.Tensor(np.asarray(expression)), weights)
        return embeddings

    def get_embeddings(self, dataset: AnnData) -> torch.Tensor:
        """Gets the L2-normalised embeddings from the GenePT model

        Parameters
        ----------
        dataset : AnnData
            The processed data.

        Returns
        -------
        torch.Tensor
            The result of `get_text_embeddings` with every row scaled to unit L2 norm.
            Rows that are entirely zero stay zero instead of becoming NaN.
        """
        LOGGER.info("Inference started:")
        embeddings = self.get_text_embeddings(dataset)
        # normalize() divides by max(norm, eps), which avoids a 0/0 for all-zero rows.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "helical"
7-
version = "0.0.1a18"
7+
version = "0.0.1a19"
88
authors = [
99
{ name="Helical Team", email="[email protected]" },
1010
]

0 commit comments

Comments
 (0)