Commit 84a135d0 authored by nd-02110114

generalize XGBoostModel to GDBTModel (support LightGBM)

parent 9ce7a2a4
.travis.yml  +1 −4
@@ -42,6 +42,7 @@ install:
script:
  - if [[ "$DOCTEST_EXAMPLES" == "true" ]]; then
      cd docs && pip install -r requirements.txt;
      make clean html;
      make doctest_examples;
      travis_terminate $?;
    fi
@@ -49,10 +50,6 @@ script:
  - bash devtools/run_flake8.sh
  - mypy -p deepchem
  - pytest -v -m "not slow" --cov=deepchem deepchem
  - if [ $TRAVIS_PYTHON_VERSION == '3.7' ]; then
      cd docs && pip install -r requirements.txt;
      make clean html && cd ..;
    fi
  - if [ $TRAVIS_PYTHON_VERSION == '3.7' ]; then
      pytest -v --ignore-glob='deepchem/**/test*.py' --doctest-modules deepchem;
    fi
deepchem/models/__init__.py  +10 −7
@@ -26,12 +26,7 @@ from deepchem.models.chemnet_models import Smiles2Vec, ChemCeption

# scikit-learn model
from deepchem.models.sklearn_models import SklearnModel

# XGBoost model
try:
  from deepchem.models.xgboost_models import XGBoostModel
except ModuleNotFoundError:
  pass
from deepchem.models.gdbt_models import GDBTModel

# PyTorch models
try:
@@ -41,7 +36,15 @@ try:
except ModuleNotFoundError:
  pass

#################### Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0. ####################
#####################################################################################
# Compatibility imports for renamed XGBoost models. Remove below with DeepChem 3.0.
#####################################################################################

from deepchem.models.gdbt_models.gdbt_model import XGBoostModel

########################################################################################
# Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0.
########################################################################################

from deepchem.models.text_cnn import TextCNNTensorGraph
from deepchem.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph, GraphConvTensorGraph, MPNNTensorGraph
deepchem/models/gdbt_models/__init__.py  +2 −0
# flake8: noqa
from deepchem.models.gdbt_models.gdbt_model import GDBTModel
\ No newline at end of file
deepchem/models/gdbt_models/gdbt_model.py  +154 −0
"""
Gradient boosting wrapper interface
"""

import os
import logging
import tempfile
import warnings
from typing import Callable, Optional, Tuple, Union

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split

from deepchem.data import Dataset
from deepchem.models.sklearn_models import SklearnModel

logger = logging.getLogger(__name__)


class GDBTModel(SklearnModel):
  """Wrapper class that wraps GDBT models as DeepChem models.

  This class supports LightGBM/XGBoost models.
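
  Examples
  --------
  A minimal usage sketch (assumes LightGBM is installed and ``train_dataset``
  is a hypothetical featurized single-task ``Dataset``):

  >>> import deepchem as dc
  >>> import lightgbm as lgb
  >>> model = dc.models.GDBTModel(lgb.LGBMRegressor(random_state=0))  # doctest: +SKIP
  >>> model.fit(train_dataset)  # doctest: +SKIP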
  """

  def __init__(self,
               model: BaseEstimator,
               model_dir: Optional[str] = None,
               early_stopping_rounds: int = 50,
               eval_metric: Optional[Union[str, Callable[..., Tuple]]] = None,
               **kwargs):
    """
    Parameters
    ----------
    model: BaseEstimator
      The scikit-learn wrapper instance of a LightGBM/XGBoost model
      (e.g. LGBMRegressor or XGBClassifier).
    model_dir: str, optional (default None)
      Path to directory where model will be stored.
    early_stopping_rounds: int, optional (default 50)
      Activates early stopping. Validation metric needs to improve at least once
      in every early_stopping_rounds round(s) to continue training.
    eval_metric: Union[str, Callable], optional (default None)
      If a string, it should be a built-in evaluation metric to use.
      If a callable, it should be a custom evaluation metric; see the official
      LightGBM/XGBoost documentation for details.
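      For example, with LightGBM's scikit-learn API a custom metric (a sketch,
      not part of this codebase) can look like
      ``lambda y_true, y_pred: ('rmse', np.sqrt(np.mean((y_true - y_pred) ** 2)), False)``,
      returning an ``(eval_name, eval_result, is_higher_better)`` tuple.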
    """
    if model_dir is not None:
      if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    else:
      model_dir = tempfile.mkdtemp()
    self.model_dir = model_dir
    self.model = model
    self.model_class = model.__class__
    self.early_stopping_rounds = early_stopping_rounds
    self.model_type = self._check_model_type()

    if eval_metric is None:
      if self.model_type == 'classification':
        self.eval_metric: Union[str, Callable[..., Tuple]] = 'auc'
      elif self.model_type == 'regression':
        self.eval_metric = 'mae'
    else:
      self.eval_metric = eval_metric

  def _check_model_type(self) -> str:
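    # Both the LightGBM and XGBoost scikit-learn wrappers follow the
    # *Classifier/*Regressor naming convention (e.g. LGBMClassifier, XGBRegressor).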
    class_name = self.model.__class__.__name__
    if class_name.endswith('Classifier'):
      return 'classification'
    elif class_name.endswith('Regressor'):
      return 'regression'
    else:
      raise ValueError(
          '{} is not a supported model instance.'.format(class_name))

  def fit(self, dataset: Dataset):
    """Fits GDBT model with all data.

    First, this function splits all data into train and valid data (8:2),
    and finds the best n_estimators. And then, we retrain all data using
    best n_estimators * 1.25.

    Parameters
    ----------
    dataset: Dataset
      The `Dataset` to train this model on.
    """
    X = dataset.X
    y = np.squeeze(dataset.y)

    # GDBT models don't support multi-output (multitask) problems
    if len(y.shape) != 1:
      raise ValueError("GDBT model doesn't support multi-output (multitask)")

    seed = self.model.random_state
    stratify = None
    if self.model_type == 'classification':
      stratify = y

    # Find the optimal n_estimators based on the original learning_rate and early_stopping_rounds
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=stratify)
    self.model.fit(
        X_train,
        y_train,
        early_stopping_rounds=self.early_stopping_rounds,
        eval_metric=self.eval_metric,
        eval_set=[(X_test, y_test)])

    # Retrain the model on the whole dataset with the best n_estimators * 1.25
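    # XGBoost's scikit-learn wrapper exposes the early-stopping result as
    # `best_ntree_limit`, while LightGBM exposes it as `best_iteration_`.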
    if self.model.__class__.__name__.startswith('XGB'):
      estimated_best_round = np.round(self.model.best_ntree_limit * 1.25)
    else:
      estimated_best_round = np.round(self.model.best_iteration_ * 1.25)
    self.model.n_estimators = np.int64(estimated_best_round)
    self.model.fit(X, y, eval_metric=self.eval_metric)

  def fit_with_eval(self, train_dataset: Dataset, valid_dataset: Dataset):
    """Fits GDBT model with valid data.

    Parameters
    ----------
    train_dataset: Dataset
      The `Dataset` to train this model on.
    valid_dataset: Dataset
      The `Dataset` to validate this model on.
    """
    X_train, X_valid = train_dataset.X, valid_dataset.X
    y_train, y_valid = np.squeeze(train_dataset.y), np.squeeze(valid_dataset.y)

    # GDBT models don't support multi-output (multitask) problems
    if len(y_train.shape) != 1 or len(y_valid.shape) != 1:
      raise ValueError("GDBT model doesn't support multi-output (multitask)")

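    # Early stopping monitors `eval_metric` on the provided validation set:
    # training stops once the metric fails to improve for
    # `early_stopping_rounds` consecutive boosting rounds.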
    self.model.fit(
        X_train,
        y_train,
        early_stopping_rounds=self.early_stopping_rounds,
        eval_metric=self.eval_metric,
        eval_set=[(X_valid, y_valid)])


#########################################
# Deprecation warnings for XGBoostModel
#########################################


class XGBoostModel(GDBTModel):

  def __init__(self, *args, **kwargs):
    warnings.warn(
        "XGBoostModel is deprecated and has been renamed to GDBTModel.",
        FutureWarning)
    super(XGBoostModel, self).__init__(*args, **kwargs)
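
# A usage sketch (hypothetical): constructing the old class still works but
# emits a warning, e.g.
#   XGBoostModel(xgb.XGBRegressor())  # FutureWarning: XGBoostModel is deprecated...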
deepchem/models/keras_model.py  +1 −3
@@ -169,9 +169,7 @@ class KerasModel(Model):
      like a printout every 10 batch steps, you'd set
      `log_frequency=10` for example.
    """
    super(KerasModel, self).__init__(
        model_instance=model, model_dir=model_dir, **kwargs)
    self.model = model
    super(KerasModel, self).__init__(model=model, model_dir=model_dir, **kwargs)
    if isinstance(loss, Loss):
      self._loss_fn: LossFn = _StandardLoss(model, loss)
    else: