Skip to content

Commit

Permalink
Merge pull request #80 from cane11/main
Browse files Browse the repository at this point in the history
cuml training LogisticRegression.
  • Loading branch information
ChuanXu1 authored Aug 11, 2023
2 parents e18b352 + 1b540e0 commit 3252d72
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion celltypist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn import preprocessing
from sklearn import __version__ as skv
from typing import Optional, Union
from .models import Model
Expand Down Expand Up @@ -122,6 +123,40 @@ def _LRClassifier(indata, labels, C, solver, max_iter, n_jobs, **kwargs) -> Logi
classifier.fit(indata, labels)
return classifier

def _cuLRClassifier(indata, labels, C, solver, max_iter, n_jobs, kwargs_cuml, **kwargs) -> LogisticRegression:
    """
    For internal use. Fit a logistic regression on GPU with cuml and hand it back as a scikit-learn classifier.

    The model is trained by cuml's `LogisticRegression`; its fitted parameters are then copied
    onto an unfitted scikit-learn `LogisticRegression` so that the returned object has the same
    type as the CPU classifiers produced in this module.

    Parameters
    ----------
    indata
        Training matrix passed as is to the cuml `fit`; `indata.shape[1]` is reported as the number of genes.
    labels
        One label per row of `indata`. They are integer-encoded for cuml and decoded again for `classes_`.
    C
        Inverse regularisation strength, given to both the cuml and the scikit-learn constructors.
    solver
        Solver recorded on the returned scikit-learn object only (it is not passed to cuml).
        `None` selects `'sag'` for more than 50000 cells and `'lbfgs'` otherwise.
    max_iter
        Maximum number of iterations for the cuml fit; also stored as `n_iter_` on the returned object.
    n_jobs
        Recorded on the returned scikit-learn object only.
    kwargs_cuml
        Extra keyword arguments for the cuml `LogisticRegression` constructor (`None` is treated as empty).
    **kwargs
        Extra keyword arguments for the scikit-learn `LogisticRegression` constructor.

    Returns
    ----------
    LogisticRegression
        A scikit-learn classifier carrying `coef_`, `intercept_`, `classes_` and `n_iter_` from the GPU fit.

    Raises
    ----------
    ImportError
        If cuml is not installed.
    ValueError
        If `solver` is not one of the solvers accepted here.
    """
    try:
        from cuml import LogisticRegression as cuLogisticRegression
    except ImportError as err:
        #keep the original cause attached and carry the hint in the exception itself
        raise ImportError("🛑 Install cuml to use GPU version of logistic regression.") from err
    no_cells = len(labels)
    #resolve and validate the solver before training, so a bad value fails fast instead of after the GPU fit
    if solver is None:
        solver = 'sag' if no_cells > 50000 else 'lbfgs'
    elif solver not in ('liblinear', 'lbfgs', 'newton-cg', 'sag', 'saga'):
        raise ValueError(
                f"🛑 Invalid `solver`, should be one of `'liblinear'`, `'lbfgs'`, `'newton-cg'`, `'sag'`, and `'saga'`")
    logger.info("🏋️ Training data using logistic regression on GPU")
    if (no_cells > 100000) and (indata.shape[1] > 10000):
        logger.warning(f"⚠️ Warning: it may take a long time to train this dataset with {no_cells} cells and {indata.shape[1]} genes, try to downsample cells and/or restrict genes to a subset (e.g., hvgs)")
    gpu_classifier = cuLogisticRegression(C = C, max_iter = max_iter, **(kwargs_cuml or {}))
    #cuml is trained on integer codes; the encoder is kept to map the codes back to the original labels
    encoder = preprocessing.LabelEncoder()
    encoded_labels = encoder.fit_transform(labels)
    gpu_classifier.fit(indata, encoded_labels)
    # Hacky solution to allow upload. Copy parameters to sklearn function.
    #the sklearn object is never fitted (hence max_iter = 1); it only serves as a container
    #NOTE(review): the shell is declared as `multi_class = 'ovr'` -- confirm this matches how cuml fits multi-class data
    classifier = LogisticRegression(C = C, solver = solver, max_iter = 1, multi_class = 'ovr', n_jobs = n_jobs, **kwargs)
    classifier.coef_ = gpu_classifier.coef_
    classifier.intercept_ = gpu_classifier.intercept_
    classifier.classes_ = encoder.inverse_transform(gpu_classifier.classes_)
    classifier.n_iter_ = max_iter
    return classifier

def _SGDClassifier(indata, labels,
alpha, max_iter, n_jobs,
mini_batch, batch_number, batch_size, epochs, balance_cell_type, **kwargs) -> SGDClassifier:
Expand Down Expand Up @@ -170,6 +205,8 @@ def train(X = None,
C: float = 1.0, solver: Optional[str] = None, max_iter: Optional[int] = None, n_jobs: Optional[int] = None,
#SGD param
use_SGD: bool = False, alpha: float = 0.0001,
#GPU param
use_GPU: bool = False,
#mini-batch
mini_batch: bool = False, batch_number: int = 100, batch_size: int = 1000, epochs: int = 10, balance_cell_type: bool = False,
#feature selection
Expand Down Expand Up @@ -320,7 +357,9 @@ def train(X = None,
indata = indata.toarray()
#max_iter
if max_iter is None:
if indata.shape[0] < 50000:
if use_GPU:
max_iter = 1000
elif indata.shape[0] < 50000:
max_iter = 1000
elif indata.shape[0] < 500000:
max_iter = 500
Expand All @@ -329,6 +368,8 @@ def train(X = None,
#classifier
if use_SGD or feature_selection:
classifier = _SGDClassifier(indata = indata, labels = labels, alpha = alpha, max_iter = max_iter, n_jobs = n_jobs, mini_batch = mini_batch, batch_number = batch_number, batch_size = batch_size, epochs = epochs, balance_cell_type = balance_cell_type, **kwargs)
elif use_GPU:
classifier = _cuLRClassifier(indata = indata, labels = labels, C = C, solver = solver, max_iter = max_iter, n_jobs = n_jobs, kwargs_cuml={}, **kwargs)
else:
classifier = _LRClassifier(indata = indata, labels = labels, C = C, solver = solver, max_iter = max_iter, n_jobs = n_jobs, **kwargs)
#feature selection -> new classifier and scaler
Expand All @@ -345,6 +386,8 @@ def train(X = None,
logger.info(f"🏋️ Starting the second round of training")
if use_SGD:
classifier = _SGDClassifier(indata = indata[:, gene_index], labels = labels, alpha = alpha, max_iter = max_iter, n_jobs = n_jobs, mini_batch = mini_batch, batch_number = batch_number, batch_size = batch_size, epochs = epochs, balance_cell_type = balance_cell_type, **kwargs)
elif use_GPU:
classifier = _cuLRClassifier(indata = indata[:, gene_index], labels = labels, C = C, solver = solver, max_iter = max_iter, n_jobs = n_jobs, kwargs_cuml={}, **kwargs)
else:
classifier = _LRClassifier(indata = indata[:, gene_index], labels = labels, C = C, solver = solver, max_iter = max_iter, n_jobs = n_jobs, **kwargs)
scaler.mean_ = scaler.mean_[gene_index]
Expand Down

0 comments on commit 3252d72

Please sign in to comment.