Unverified Commit f272f7cb authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2215 from nd-02110114/sklearn-model

Generalize XGBoostModel to GBDTModel (support LightGBM)
parents 02fce357 09828fbd
Loading
Loading
Loading
Loading

.python-version

deleted100644 → 0
+0 −1
Original line number Diff line number Diff line
miniconda3-latest
+9 −11
Original line number Diff line number Diff line
jobs:
  include:
    - name: Python 3.6
    - name: Linux Python 3.6
      language: python
      python: '3.6'
      sudo: required
      dist: xenial
    - name: Python 3.7
    - name: Linux Python 3.7
      language: python
      python: '3.7'
      sudo: required
      dist: xenial
    - name: Windows
      env: NIGHTLY_PKG_PUBLISH=true
    - name: Windows Python 3.7
      language: c
      python: '3.7'
      os: windows
    - name: DocTest Examples
    - name: Documentation
      language: python
      python: '3.7'
      sudo: required
      dist: xenial
      env: DOCTEST_EXAMPLES=true
      env: CHECK_ONLY_DOCS=true
cache: pip
install:
  - if [[ "$TRAVIS_OS_NAME" != "windows" ]]; then
@@ -40,8 +41,9 @@ install:
  - conda activate deepchem
  - pip install -e .
script:
  - if [[ "$DOCTEST_EXAMPLES" == "true" ]]; then
  - if [[ "$CHECK_ONLY_DOCS" == "true" ]]; then
      cd docs && pip install -r requirements.txt;
      make clean html;
      make doctest_examples;
      travis_terminate $?;
    fi
@@ -49,10 +51,6 @@ script:
  - bash devtools/run_flake8.sh
  - mypy -p deepchem
  - pytest -v -m "not slow" --cov=deepchem deepchem
  - if [ $TRAVIS_PYTHON_VERSION == '3.7' ]; then
      cd docs && pip install -r requirements.txt;
      make clean html && cd ..;
    fi
  - if [ $TRAVIS_PYTHON_VERSION == '3.7' ]; then
      pytest -v --ignore-glob='deepchem/**/test*.py' --doctest-modules deepchem;
    fi
@@ -66,4 +64,4 @@ deploy:
    secure: b67LO8VZcoKEWo7gDlFdjS1yKUavCt578uAuXPyW6f+e+Tk/sEQRdkx1VYoZlQdfZQo8u4q+E3W184T+/j6ht65/cdy/HYH57LCQySjF/MY2M9+/lcP45aY7Z0F2QHeY9QgpRc8gKthGzgM/bHj2glxlEvT1diItEEoGqE2x/fw1K25cNOni08E4hqz0HPY1SXVwd8/9Z/t1YasrBcOjtJ8kcbyjnmeyhjfkaV/aTaAzuqh2MlqZTSz3dhwsBrZfZp86+8T2TgcoDSuIxCwb777QKW1QlvNyLEKlnfateKMYqrrP65oHrxXEEcHd/N3IH28Bz9wVnENjHLkGJ0vXyXyEWcJFe+V6T0k/8NkZamU4SZE5BM4v6mOdThs4l54vuFajctHDeGgIDjL55MfkDmkKd5lAvlWPwrdw8DERsmqetUfZ/TG7FE6/MT1puu2ffu3A9Ivcch5T46pojIggDWHHn9hUsc6iD3Ov7rVqd024Lzm9V8wXiDYU9EMqAu5lJQRIOO/hnr8Gn6zYRCE1n29MKuNJwauSHfdV/mBTRyOjZyWHSGNaiPw2hqE3tZrrIN4koEYaZiERRVnmVt8wMUTj40YglosTHYpL91SkDH/ResX1rtHKs4Am+R+MmcWULTUQ7UwEtqlsa3nVxTK9gfmJ0nX8Jhjtl2iRhVg5PP8=
  edge: true
  on:
    condition: $TRAVIS_OS_NAME = linux && $TRAVIS_PYTHON_VERSION = 3.7
    condition: $NIGHTLY_PKG_PUBLISH = true
+10 −7
Original line number Diff line number Diff line
@@ -26,12 +26,7 @@ from deepchem.models.chemnet_models import Smiles2Vec, ChemCeption

# scikit-learn model
from deepchem.models.sklearn_models import SklearnModel

# XGBoost model
try:
  from deepchem.models.xgboost_models import XGBoostModel
except ModuleNotFoundError:
  pass
from deepchem.models.gbdt_models import GBDTModel

# PyTorch models
try:
@@ -41,7 +36,15 @@ try:
except ModuleNotFoundError:
  pass

#################### Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0. ####################
#####################################################################################
# Compatibility imports for renamed XGBoost models. Remove below with DeepChem 3.0.
#####################################################################################

from deepchem.models.gbdt_models.gbdt_model import XGBoostModel

########################################################################################
# Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0.
########################################################################################

from deepchem.models.text_cnn import TextCNNTensorGraph
from deepchem.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph, GraphConvTensorGraph, MPNNTensorGraph
+2 −0
Original line number Diff line number Diff line
# flake8: noqa
from deepchem.models.gbdt_models.gbdt_model import GBDTModel
 No newline at end of file
+158 −0
Original line number Diff line number Diff line
"""
Gradient Boosting Decision Tree wrapper interface
"""

import os
import logging
import tempfile
import warnings
from typing import Callable, Optional, Union

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split

from deepchem.data import Dataset
from deepchem.models.sklearn_models import SklearnModel

logger = logging.getLogger(__name__)


class GBDTModel(SklearnModel):
  """Wrapper class that wraps GBDT models as DeepChem models.

  This class supports LightGBM/XGBoost models.
  """

  def __init__(self,
               model: BaseEstimator,
               model_dir: Optional[str] = None,
               early_stopping_rounds: int = 50,
               eval_metric: Optional[Union[str, Callable]] = None,
               **kwargs):
    """
    Parameters
    ----------
    model: BaseEstimator
      The model instance of scikit-learn wrapper LightGBM/XGBoost models.
    model_dir: str, optional (default None)
      Path to directory where model will be stored. If None, a temporary
      directory is created.
    early_stopping_rounds: int, optional (default 50)
      Activates early stopping. Validation metric needs to improve at least once
      in every early_stopping_rounds round(s) to continue training.
    eval_metric: Union[str, Callable], optional (default None)
      If string, it should be a built-in evaluation metric to use.
      If callable, it should be a custom evaluation metric, see official note for more details.
      If None, 'auc' is used for classifiers and 'mae' for regressors.
    """
    if model_dir is not None:
      if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    else:
      model_dir = tempfile.mkdtemp()
    self.model_dir = model_dir
    self.model = model
    self.model_class = model.__class__
    self.early_stopping_rounds = early_stopping_rounds
    self.model_type = self._check_model_type()

    # Pick a sensible default metric for the task type when none was given;
    # for the 'none' model type there is no default and eval_metric stays None.
    if eval_metric is None:
      default_metrics = {'classification': 'auc', 'regression': 'mae'}
      self.eval_metric: Optional[Union[str, Callable]] = default_metrics.get(
          self.model_type)
    else:
      self.eval_metric = eval_metric

  def _check_model_type(self) -> str:
    """Infer the task type from the wrapped model's class name.

    Returns
    -------
    str
      'classification' for `*Classifier`, 'regression' for `*Regressor`,
      'none' when no model was supplied (model is None).

    Raises
    ------
    ValueError
      If the model class name matches none of the supported patterns.
    """
    class_name = self.model.__class__.__name__
    if class_name.endswith('Classifier'):
      return 'classification'
    elif class_name.endswith('Regressor'):
      return 'regression'
    elif class_name == 'NoneType':
      return 'none'
    else:
      raise ValueError(
          '{} is not a supported model instance.'.format(class_name))

  def fit(self, dataset: Dataset):
    """Fits GBDT model with all data.

    First, this function splits all data into train and valid data (8:2),
    and finds the best n_estimators. And then, we retrain all data using
    best n_estimators * 1.25.

    Parameters
    ----------
    dataset: Dataset
      The `Dataset` to train this model on.

    Raises
    ------
    ValueError
      If the dataset has more than one task (multi-output is unsupported).
    """
    X = dataset.X
    y = np.squeeze(dataset.y)

    # GBDT models don't support multi-output (multitask) learning.
    if y.ndim != 1:
      raise ValueError("GBDT model doesn't support multi-output(task)")

    seed = self.model.random_state
    # Stratify the split for classification so class ratios are preserved.
    stratify = y if self.model_type == 'classification' else None

    # Find optimal n_estimators via early stopping on a held-out 20% split,
    # keeping the model's original learning_rate.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=stratify)
    self.model.fit(
        X_train,
        y_train,
        early_stopping_rounds=self.early_stopping_rounds,
        eval_metric=self.eval_metric,
        eval_set=[(X_test, y_test)])

    # Retrain on the whole data with best n_estimators * 1.25 — the final fit
    # sees 25% more data than the early-stopping run, so allow more trees.
    # XGBoost and LightGBM expose the best round under different attributes.
    if self.model.__class__.__name__.startswith('XGB'):
      best_round = self.model.best_ntree_limit
    else:
      best_round = self.model.best_iteration_
    self.model.n_estimators = np.int64(np.round(best_round * 1.25))
    self.model.fit(X, y, eval_metric=self.eval_metric)

  def fit_with_eval(self, train_dataset: Dataset, valid_dataset: Dataset):
    """Fits GBDT model with user-supplied valid data for early stopping.

    Parameters
    ----------
    train_dataset: Dataset
      The `Dataset` to train this model on.
    valid_dataset: Dataset
      The `Dataset` to validate this model on.

    Raises
    ------
    ValueError
      If either dataset has more than one task (multi-output is unsupported).
    """
    X_train, X_valid = train_dataset.X, valid_dataset.X
    y_train, y_valid = np.squeeze(train_dataset.y), np.squeeze(valid_dataset.y)

    # GBDT models don't support multi-output (multitask) learning.
    if y_train.ndim != 1 or y_valid.ndim != 1:
      raise ValueError("GBDT model doesn't support multi-output(task)")

    self.model.fit(
        X_train,
        y_train,
        early_stopping_rounds=self.early_stopping_rounds,
        eval_metric=self.eval_metric,
        eval_set=[(X_valid, y_valid)])


#########################################
# Deprecation warnings for XGBoostModel
#########################################


class XGBoostModel(GBDTModel):
  """Deprecated alias for `GBDTModel`, kept for backward compatibility."""

  def __init__(self, *args, **kwargs):
    # Warn callers to migrate; construction is otherwise delegated unchanged.
    warnings.warn(
        "XGBoostModel is deprecated and has been renamed to GBDTModel.",
        FutureWarning)
    super().__init__(*args, **kwargs)
Loading