Skip to content

Commit

Permalink
Detect the expression format from the top 1000 rows for prediction and training
Browse files (browse the repository at this point in the history)
  • Loading branch information
ChuanXu1 committed Aug 12, 2023
1 parent c62ca37 commit dd7e6a5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions celltypist/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __init__(self, filename: Union[AnnData,str] = "", model: Union[Model,str] =
f"🛑 The number of cells in {cell_file} does not match the number of cells in {self.filename}")
self.adata.var_names = genes_mtx
self.adata.obs_names = cells_mtx
if not float(self.adata.X.max()).is_integer():
if not float(self.adata.X[:1000].max()).is_integer():
logger.warn(f"⚠️ Warning: the input file seems not a raw count matrix. The prediction result may not be accurate")
if (self.adata.n_vars >= 100000) or (len(self.adata.var_names[0]) >= 30) or (len(self.adata.obs_names.intersection(['GAPDH', 'ACTB', 'CALM1', 'PTPRC', 'MALAT1'])) >= 1):
logger.warn(f"⚠️ The input matrix is detected to be a gene-by-cell matrix, will transpose it")
Expand All @@ -301,7 +301,7 @@ def __init__(self, filename: Union[AnnData,str] = "", model: Union[Model,str] =
elif isinstance(filename, AnnData) or (isinstance(filename, str) and filename.endswith('.h5ad')):
self.adata = sc.read(filename) if isinstance(filename, str) else filename
self.adata.var_names_make_unique()
if (self.adata.X.min() < 0) or (self.adata.X.max() > np.log1p(10000)):
if (self.adata.X[:1000].min() < 0) or (self.adata.X[:1000].max() > np.log1p(10000)):
logger.info("👀 Invalid expression matrix in `.X`, expect log1p normalized expression to 10000 counts per cell; will try the `.raw` attribute")
try:
self.indata = self.adata.raw.X
Expand All @@ -310,7 +310,7 @@ def __init__(self, filename: Union[AnnData,str] = "", model: Union[Model,str] =
except Exception as e:
raise Exception(
f"🛑 Fail to use the `.raw` attribute in the input object. {e}")
if (self.indata.min() < 0) or (self.indata.max() > np.log1p(10000)):
if (self.indata[:1000].min() < 0) or (self.indata[:1000].max() > np.log1p(10000)):
raise ValueError(
"🛑 Invalid expression matrix in both `.X` and `.raw.X`, expect log1p normalized expression to 10000 counts per cell")
else:
Expand Down Expand Up @@ -391,7 +391,7 @@ def _construct_neighbor_graph(adata: AnnData) -> tuple:
adata.uns['log1p']['base'] = None

if 'X_pca' not in adata.obsm.keys():
if adata.X.min() < 0:
if adata.X[:1000].min() < 0:
adata = adata.raw.to_adata()
if 'highly_variable' not in adata.var:
sc.pp.filter_genes(adata, min_cells=5)
Expand Down
4 changes: 2 additions & 2 deletions celltypist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _prepare_data(X, labels, genes, transpose) -> tuple:
if isinstance(X, AnnData) or (isinstance(X, str) and X.endswith('.h5ad')):
adata = sc.read(X) if isinstance(X, str) else X
adata.var_names_make_unique()
if adata.X.min() < 0:
if adata.X[:1000].min() < 0:
logger.info("👀 Detected scaled expression in the input data, will try the .raw attribute")
try:
indata = adata.raw.X
Expand Down Expand Up @@ -80,7 +80,7 @@ def _prepare_data(X, labels, genes, transpose) -> tuple:
f"🛑 The number of genes provided does not match the number of genes in {X}")
adata.var_names = np.array(genes)
adata.var_names_make_unique()
if not float(adata.X.max()).is_integer():
if not float(adata.X[:1000].max()).is_integer():
logger.warn(f"⚠️ Warning: the input file seems not a raw count matrix. The trained model may be biased")
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
Expand Down

0 comments on commit dd7e6a5

Please sign in to comment.