Commit 93e1e761 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #796 from lilleswing/sklearn

Sklearn Models with no weights
parents 21852c42 d6d50d1f
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
from flaky import flaky

import deepchem as dc
from deepchem.models.tensorgraph.layers import Feature, Label, Dense, L2Loss
import numpy as np
@@ -7,6 +9,7 @@ import unittest

class TestMAML(unittest.TestCase):

  @flaky
  def test_sine(self):
    """Test meta-learning for sine function."""

+39 −6
Original line number Diff line number Diff line
@@ -2,9 +2,11 @@
Code for processing datasets using scikit-learn.
"""
import numpy as np
from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.linear_model import LogisticRegression, BayesianRidge
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import LassoCV
@@ -14,11 +16,42 @@ from deepchem.models import Model
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk

NON_WEIGHTED_MODELS = [
    LogisticRegression, PLSRegression, GaussianProcessRegressor, ElasticNetCV,
    LassoCV, BayesianRidge
]


class SklearnModel(Model):
  """
  Abstract base class for different ML models.
  """

  def __init__(self,
               model_instance=None,
               model_dir=None,
               verbose=True,
               **kwargs):
    """
    Parameters
    ----------
    model_instance: sklearn model
    model_dir: str
    verbose: bool
    kwargs: dict
      kwargs['use_weights'] is a bool which determines if we pass weights into
      self.model_instance.fit()
    """
    super(SklearnModel, self).__init__(model_instance, model_dir, verbose,
                                       **kwargs)
    if 'use_weights' in kwargs:
      self.use_weights = kwargs['use_weights']
    else:
      self.use_weights = True
    for model_instance in NON_WEIGHTED_MODELS:
      if isinstance(self.model_instance, model_instance):
        self.use_weights = False

  def fit(self, dataset, **kwargs):
    """
    Fits SKLearn model to data.
@@ -27,11 +60,10 @@ class SklearnModel(Model):
    y = np.squeeze(dataset.y)
    w = np.squeeze(dataset.w)
    # Logistic regression doesn't support weights
    if not isinstance(self.model_instance, LogisticRegression):
    if self.use_weights:
      self.model_instance.fit(X, y, w)
    else:
      return
    self.model_instance.fit(X, y)
    y_pred_raw = self.model_instance.predict(X)

  def predict_on_batch(self, X, pad_batch=False):
    """
@@ -73,7 +105,8 @@ class SklearnModel(Model):

  def reload(self):
    """Loads sklearn model from joblib file on disk."""
    self.model_instance = load_from_disk(Model.get_model_filename(self.model_dir))
    self.model_instance = load_from_disk(
        Model.get_model_filename(self.model_dir))

  def get_num_tasks(self):
    """Number of tasks for this model. Defaults to 1"""