Commit

More cleanup
Bharath Ramsundar authored and committed Mar 14, 2020
1 parent 894dba2 commit 5f1e288
Showing 6 changed files with 134 additions and 8 deletions.
Empty file added torchchem/feat/__init__.py
File renamed without changes.
128 changes: 121 additions & 7 deletions torchchem/models/model.py
@@ -1,16 +1,131 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 13 22:31:24 2017
@author: Zhenqin Wu
This file implements the base Model class for torchchem models.
"""

import os
import shutil
import tempfile
import time
import torch
import numpy as np
#from deepchem.trans import undo_transforms
from deepchem.utils.save import log

class Model(object):
"""
Abstract base class for different ML models.
"""

def __init__(self,
model_instance=None,
model_dir=None,
verbose=True,
**kwargs):
"""Abstract class for all models.
Parameters
----------
model_instance: object
Wrapper around ScikitLearn/Keras/Tensorflow model object.
model_dir: str
Path to directory where model will be stored.
"""
self.model_dir_is_temp = False
if model_dir is not None:
if not os.path.exists(model_dir):
os.makedirs(model_dir)
else:
model_dir = tempfile.mkdtemp()
self.model_dir_is_temp = True
self.model_dir = model_dir
self.model_instance = model_instance
self.model_class = model_instance.__class__

self.verbose = verbose

def __del__(self):
if 'model_dir_is_temp' in dir(self) and self.model_dir_is_temp:
shutil.rmtree(self.model_dir)

def fit_on_batch(self, X, y, w):
"""
Updates existing model with new information.
"""
raise NotImplementedError(
"Each model is responsible for its own fit_on_batch method.")

def predict_on_batch(self, X, **kwargs):
"""
Makes predictions on given batch of new data.
Parameters
----------
X: np.ndarray
Features
"""
raise NotImplementedError(
"Each model is responsible for its own predict_on_batch method.")

def reload(self):
"""
Reload trained model from disk.
"""
raise NotImplementedError(
"Each model is responsible for its own reload method.")

def save(self):
"""Dispatcher function for saving.
Each subclass is responsible for overriding this method.
"""
raise NotImplementedError

def fit(self, dataset, nb_epoch=10, batch_size=50, **kwargs):
"""
Fits a model on data in a Dataset object.
"""
raise NotImplementedError

def predict(self, dataset, transformers=[], batch_size=None):
"""
Uses self to make predictions on provided Dataset object.
Returns:
y_pred: numpy ndarray of shape (n_samples,)
"""
raise NotImplementedError

def evaluate(self, dataset, metrics, transformers=[], per_task_metrics=False):
"""
Evaluates the performance of this model on specified dataset.
Parameters
----------
dataset: dc.data.Dataset
Dataset object.
metrics: list of deepchem.metrics.Metric
Evaluation metrics to compute.
transformers: list
List of deepchem.transformers.Transformer
per_task_metrics: bool
If True, return per-task scores.
Returns
-------
dict
Maps tasks to scores under metric.
"""
raise NotImplementedError

def get_task_type(self):
"""
Currently models can only be classifiers or regressors.
"""
raise NotImplementedError

def get_num_tasks(self):
"""
Get number of tasks.
"""
raise NotImplementedError


class MultitaskModel(Model):
@@ -153,7 +268,6 @@ def fit(self,
for ind, (X_b, y_b, w_b, ids_b) in enumerate(
# Turns out there are valid cases where we don't want pad-batches
# on by default.
#dataset.iterbatches(batch_size, pad_batches=True)):
dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches)):
if ind % log_every_N_batches == 0:
log("On batch %d" % ind, self.verbose)
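To illustrate how this abstract base class is meant to be used, here is a minimal sketch of a concrete subclass. The class name, network, and training details below are illustrative assumptions, not part of this commit; only fit_on_batch and predict_on_batch are filled in.

import torch
import torch.nn as nn
from torchchem.models.model import Model

class TinyRegressor(Model):
  """Hypothetical single-task regressor built on the abstract Model class."""

  def __init__(self, n_features, model_dir=None, **kwargs):
    # A single linear layer stands in for a real network.
    net = nn.Linear(n_features, 1)
    super(TinyRegressor, self).__init__(
        model_instance=net, model_dir=model_dir, **kwargs)
    self.optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    self.criterion = nn.MSELoss()

  def fit_on_batch(self, X, y, w):
    # One gradient step on a single batch of numpy arrays.
    X_t = torch.as_tensor(X, dtype=torch.float32)
    y_t = torch.as_tensor(y, dtype=torch.float32).view(-1, 1)
    self.optimizer.zero_grad()
    loss = self.criterion(self.model_instance(X_t), y_t)
    loss.backward()
    self.optimizer.step()

  def predict_on_batch(self, X, **kwargs):
    # Forward pass without gradient tracking; returns a numpy array.
    with torch.no_grad():
      X_t = torch.as_tensor(X, dtype=torch.float32)
      return self.model_instance(X_t).numpy()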
2 changes: 1 addition & 1 deletion torchchem/models/multitask_classification.py
@@ -4,7 +4,7 @@

import torch
import numpy as np
-from deepchem.metrics import from_one_hot
+from torchchem.utils import from_one_hot
from torchchem.models.model import MultitaskModel


12 changes: 12 additions & 0 deletions torchchem/utils.py
@@ -0,0 +1,12 @@
"""
Shared utility functions for torchchem.
"""
import numpy as np

def from_one_hot(y, axis=1):
"""Transorms label vector from one-hot encoding.
y: np.ndarray
A vector of shape [n_samples, num_classes]
"""
return np.argmax(y, axis=axis)
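
For reference, a minimal usage sketch of the new helper; the sample array below is made up for illustration.

import numpy as np
from torchchem.utils import from_one_hot

# Hypothetical one-hot labels for three samples and two classes.
y = np.array([[1, 0],
              [0, 1],
              [1, 0]])

print(from_one_hot(y))  # -> [0 1 0]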
