Skip to content

Commit

Permalink
Fix cuml import
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuanXu1 committed Nov 1, 2023
1 parent b31d8f8 commit 77abaa3
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions celltypist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from . import logger
from scipy.sparse import spmatrix
from datetime import datetime
import sys
try:
from cuml import LogisticRegression as cuLogisticRegression
except ImportError:
pass

def _to_vector(_vector_or_file):
"""
Expand Down Expand Up @@ -310,12 +315,9 @@ def train(X = None,
An instance of the :class:`~celltypist.models.Model` trained by celltypist.
"""
#Test GPU
if not use_SGD and use_GPU:
try:
from cuml import LogisticRegression as cuLogisticRegression
except ImportError:
logger.warn(f"⚠️ Warning: to run logistic regression on GPU, please first install cuml")
return
if not use_SGD and use_GPU and 'cuml' not in sys.modules:
logger.warn(f"⚠️ Warning: to run logistic regression on GPU, please first install cuml")
return
#prepare
logger.info("🍳 Preparing data before training")
indata, labels, genes = _prepare_data(X, labels, genes, transpose_input)
Expand Down

0 comments on commit 77abaa3

Please sign in to comment.