Commit d6d50d1f authored by Karl Leswing's avatar Karl Leswing
Browse files

PR comments

parent b00c4111
Loading
Loading
Loading
Loading
+19 −4
Original line number Diff line number Diff line
@@ -16,10 +16,10 @@ from deepchem.models import Model
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk

NON_WEIGHTED_MODELS = {
NON_WEIGHTED_MODELS = [
    LogisticRegression, PLSRegression, GaussianProcessRegressor, ElasticNetCV,
    LassoCV, BayesianRidge
}
]


class SklearnModel(Model):
@@ -27,8 +27,23 @@ class SklearnModel(Model):
  Abstract base class for different ML models.
  """

  def __init__(self, **kwargs):
    super(SklearnModel, self).__init__(**kwargs)
  def __init__(self,
               model_instance=None,
               model_dir=None,
               verbose=True,
               **kwargs):
    """
    Parameters
    ----------
    model_instance: sklearn model
    model_dir: str
    verbose: bool
    kwargs: dict
      kwargs['use_weights'] is a bool which determines if we pass weights into
      self.model_instance.fit()
    """
    super(SklearnModel, self).__init__(model_instance, model_dir, verbose,
                                       **kwargs)
    if 'use_weights' in kwargs:
      self.use_weights = kwargs['use_weights']
    else: