Commit 49ae1606 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #483 from haozhenWu/xgboostModel

[MoleculeNet] XGBoost models
parents 757c5913 f6402cd1
Loading
Loading
Loading
Loading
+58 −20
Original line number Diff line number Diff line
@@ -215,36 +215,43 @@ Index splitting
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.903              |0.705              |
|           |Random Forest       |0.999              |0.733              |
|           |XGBoost             |0.891              |0.753              |
|           |IRV                 |0.811              |0.767              |
|           |Multitask network   |0.856              |0.763              |
|           |robust MT-NN        |0.857              |0.767              |
|           |graph convolution   |0.872              |0.798              |
|muv        |logistic regression |0.963              |0.766              |
|           |XGBoost             |0.895              |0.714              |
|           |Multitask network   |0.904              |0.764              |
|           |robust MT-NN        |0.934              |0.781              |
|           |graph convolution   |0.840              |0.823              |
|pcba       |logistic regression |0.809              |0.776              |
|           |XGBoost             |0.931              |0.847              |
|           |Multitask network   |0.826              |0.802              |
|           |robust MT-NN        |0.809              |0.783              |
|           |graph convolution   |0.876              |0.852              |
|sider      |logistic regression |0.933              |0.620              |
|           |Random Forest       |0.999              |0.670              |
|           |XGBoost             |0.829              |0.639              |
|           |IRV                 |0.649              |0.642              |
|           |Multitask network   |0.775              |0.634              |
|           |robust MT-NN        |0.803              |0.632              |
|           |graph convolution   |0.708              |0.594              |
|toxcast    |logistic regression |0.721              |0.575              |
|           |XGBoost             |0.738              |0.621              |
|           |Multitask network   |0.830              |0.678              |
|           |robust MT-NN        |0.825              |0.680              |
|           |graph convolution   |0.821              |0.720              |
|clintox    |logistic regression |0.967              |0.676              |
|           |Random Forest       |0.995              |0.776              |
|           |XGBoost             |0.879              |0.890              |
|           |IRV                 |0.763              |0.814              |
|           |Multitask network   |0.934              |0.830              |
|           |robust MT-NN        |0.949              |0.827              |
|           |graph convolution   |0.946              |0.860              |
|hiv        |logistic regression |0.864              |0.739              |
|           |Random Forest       |0.999              |0.720              |
|           |XGBoost             |0.917              |0.745              |
|           |IRV                 |0.841              |0.724              |
|           |Multitask network   |0.761              |0.652              |
|           |robust MT-NN        |0.780              |0.708              |
@@ -256,11 +263,13 @@ Random splitting
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.902              |0.715              |
|           |Random Forest       |0.999              |0.764              |
|           |XGBoost             |0.874              |0.773              |
|           |IRV                 |0.808              |0.767              |
|           |Multitask network   |0.844              |0.795              |
|           |robust MT-NN        |0.855              |0.773              |
|           |graph convolution   |0.865              |0.827              |
|muv        |logistic regression |0.957              |0.719              |
|           |XGBoost             |0.874              |0.696              |
|           |Multitask network   |0.902              |0.734              |
|           |robust MT-NN        |0.933              |0.732              |
|           |graph convolution   |0.860              |0.730              |
@@ -270,22 +279,26 @@ Random splitting
|           |graph convolution   |0.872              |0.844              |
|sider      |logistic regression |0.929              |0.656              |
|           |Random Forest       |0.999              |0.665              |
|           |XGBoost             |0.824              |0.635              |
|           |IRV                 |0.648              |0.596              |
|           |Multitask network   |0.777              |0.655              |
|           |robust MT-NN        |0.804              |0.630              |
|           |graph convolution   |0.705              |0.618              |
|toxcast    |logistic regression |0.725              |0.586              |
|           |XGBoost             |0.738              |0.633              |
|           |Multitask network   |0.836              |0.684              |
|           |robust MT-NN        |0.822              |0.681              |
|           |graph convolution   |0.820              |0.717              |
|clintox    |logistic regression |0.972              |0.725              |
|           |Random Forest       |0.997              |0.670              |
|           |XGBoost             |0.886              |0.731              |
|           |IRV                 |0.809              |0.846              |
|           |Multitask network   |0.951              |0.834              |
|           |robust MT-NN        |0.959              |0.830              |
|           |graph convolution   |0.975              |0.876              |
|hiv        |logistic regression |0.860              |0.806              |
|           |Random Forest       |0.999              |0.850              |
|           |XGBoost             |0.933              |0.841              |
|           |IRV                 |0.839              |0.809              |
|           |Multitask network   |0.742              |0.715              |
|           |robust MT-NN        |0.753              |0.727              |
@@ -297,11 +310,13 @@ Scaffold splitting
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.900              |0.650              |
|           |Random Forest       |0.999              |0.629              |
|           |XGBoost             |0.881              |0.703              |
|           |IRV                 |0.823              |0.708              |
|           |Multitask network   |0.863              |0.703              |
|           |robust MT-NN        |0.861              |0.710              |
|           |graph convolution   |0.885              |0.732              |
|muv        |logistic regression |0.947              |0.767              |
|           |XGBoost             |0.875              |0.705              |
|           |Multitask network   |0.899              |0.762              |
|           |robust MT-NN        |0.944              |0.726              |
|           |graph convolution   |0.872              |0.795              |
@@ -311,22 +326,26 @@ Scaffold splitting
|           |graph convolution   |0.874              |0.817              |
|sider      |logistic regression |0.926              |0.592              |
|           |Random Forest       |0.999              |0.619              |
|           |XGBoost             |0.796              |0.560              |
|           |IRV                 |0.639              |0.599              |
|           |Multitask network   |0.776              |0.557              |
|           |robust MT-NN        |0.797              |0.560              |
|           |graph convolution   |0.722              |0.583              |
|toxcast    |logistic regression |0.716              |0.492              |
|           |XGBoost             |0.741              |0.587              |
|           |Multitask network   |0.828              |0.617              |
|           |robust MT-NN        |0.830              |0.614              |
|           |graph convolution   |0.832              |0.638              |
|clintox    |logistic regression |0.960              |0.803              |
|           |Random Forest       |0.993              |0.735              |
|           |XGBoost             |0.873              |0.850              |
|           |IRV                 |0.793              |0.718              |
|           |Multitask network   |0.947              |0.862              |
|           |robust MT-NN        |0.953              |0.890              |
|           |graph convolution   |0.957              |0.823              |
|hiv        |logistic regression |0.858              |0.798              |
|           |Random Forest       |0.946              |0.562              |
|           |XGBoost             |0.927              |0.830              |
|           |IRV                 |0.847              |0.811              |
|           |Multitask network   |0.775              |0.765              |
|           |robust MT-NN        |0.785              |0.748              |
@@ -337,27 +356,36 @@ Scaffold splitting
|Dataset         |Model               |Splitting   |Train score/R2|Valid score/R2|
|----------------|--------------------|------------|--------------|--------------|
|delaney         |Random Forest       |Index       |0.953         |0.626         |
|                |XGBoost             |Index       |0.898         |0.664         |
|                |NN regression       |Index       |0.868         |0.578         |
|                |graphconv regression|Index       |0.967         |0.790         |
|                |Random Forest       |Random      |0.951         |0.684         |
|                |XGBoost             |Random      |0.927         |0.727         |
|                |NN regression       |Random      |0.865         |0.574         |
|                |graphconv regression|Random      |0.964         |0.782         |
|                |Random Forest       |Scaffold    |0.953         |0.284         |
|                |XGBoost             |Scaffold    |0.890         |0.316         |
|                |NN regression       |Scaffold    |0.866         |0.342         |
|                |graphconv regression|Scaffold    |0.967         |0.606         |
|sampl           |Random Forest       |Index       |0.968         |0.736         |
|                |XGBoost             |Index       |0.884         |0.784         |
|                |NN regression       |Index       |0.917         |0.764         |
|                |graphconv regression|Index       |0.982         |0.864         |
|                |Random Forest       |Random      |0.967         |0.752         |
|                |XGBoost             |Random      |0.906         |0.745         |
|                |NN regression       |Random      |0.908         |0.830         |
|                |graphconv regression|Random      |0.987         |0.868         |
|                |Random Forest       |Scaffold    |0.966         |0.473         |
|                |XGBoost             |Scaffold    |0.918         |0.439         |
|                |NN regression       |Scaffold    |0.891         |0.217         |
|                |graphconv regression|Scaffold    |0.985         |0.666         |
|nci             |NN regression       |Index       |0.171         |0.062         |
|nci             |XGBoost             |Index       |0.441         |0.066         |
|                |NN regression       |Index       |0.171         |0.062         |
|                |graphconv regression|Index       |0.123         |0.048         |
|                |XGBoost             |Random      |0.409         |0.106         |
|                |NN regression       |Random      |0.168         |0.085         |
|                |graphconv regression|Random      |0.117         |0.076         |
|                |XGBoost             |Scaffold    |0.445         |0.046         |
|                |NN regression       |Scaffold    |0.180         |0.052         |
|                |graphconv regression|Scaffold    |0.131         |0.046         |
|pdbbind(core)   |Random Forest       |Random      |0.969         |0.445         |
@@ -418,36 +446,43 @@ Time needed for benchmark test(~20h in total)
|Dataset         |Model               |Time(loading)/s |Time(running)/s|
|----------------|--------------------|----------------|---------------|
|tox21           |logistic regression |30              |60             |
|                |XGBoost             |30              |1500           |
|                |Multitask network   |30              |60             |
|                |robust MT-NN        |30              |90             |
|                |random forest       |30              |6000           |
|                |IRV                 |30              |650            |
|                |graph convolution   |40              |160            |
|muv             |logistic regression |600             |450            |
|                |XGBoost             |600             |3492           |
|                |Multitask network   |600             |400            |
|                |robust MT-NN        |600             |550            |
|                |graph convolution   |800             |1800           |
|pcba            |logistic regression |1800            |10000          |
|                |XGBoost             |1800            |470521         |
|                |Multitask network   |1800            |9000           |
|                |robust MT-NN        |1800            |14000          |
|                |graph convolution   |2200            |14000          |
|sider           |logistic regression |15              |80             |
|                |XGBoost             |15              |660            |
|                |Multitask network   |15              |75             |
|                |robust MT-NN        |15              |150            |
|                |random forest       |15              |2200           |
|                |IRV                 |15              |150            |
|                |graph convolution   |20              |50             |
|toxcast         |logistic regression |80              |2600           |
|                |XGBoost             |80              |30000          |
|                |Multitask network   |80              |2300           |
|                |robust MT-NN        |80              |4000           |
|                |graph convolution   |80              |900            |
|clintox         |logistic regression |15              |10             |
|                |XGBoost             |15              |33             |
|                |Multitask network   |15              |20             |
|                |robust MT-NN        |15              |30             |
|                |random forest       |15              |200            |
|                |IRV                 |15              |10             |
|                |graph convolution   |20              |130            |
|hiv             |logistic regression |180             |40             |
|                |XGBoost             |180             |1062           |
|                |Multitask network   |180             |350            |
|                |robust MT-NN        |180             |450            |
|                |random forest       |180             |2800           |
@@ -456,11 +491,14 @@ Time needed for benchmark test(~20h in total)
|delaney         |MT-NN regression    |10              |40             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |30             |
|                |XGBoost             |10              |51             |
|sampl           |MT-NN regression    |10              |30             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |20             |
|                |XGBoost             |10              |18             |
|nci             |MT-NN regression    |400             |1200           |
|                |graphconv regression|400             |2500           |
|                |XGBoost             |400             |28096          |
|pdbbind(core)   |MT-NN regression    |0(featurized)   |30             |
|pdbbind(refined)|MT-NN regression    |0(featurized)   |40             |
|pdbbind(full)   |MT-NN regression    |0(featurized)   |60             |
+1 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from __future__ import unicode_literals

from deepchem.models.models import Model
from deepchem.models.sklearn_models import SklearnModel
from deepchem.models.xgboost_models import XGBoostModel
from deepchem.models.tf_new_models.multitask_classifier import MultitaskGraphClassifier
from deepchem.models.tf_new_models.multitask_regressor import MultitaskGraphRegressor
from deepchem.models.tf_new_models.DTNN_regressor import DTNNGraphRegressor
+92 −1
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import deepchem as dc
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
import xgboost


class TestGeneralize(unittest.TestCase):
@@ -188,3 +189,93 @@ class TestGeneralize(unittest.TestCase):
  #  scores = model.evaluate(test_dataset, [classification_metric])
  #  for score in scores[classification_metric.name]:
  #    assert score > .5

  def test_xgboost_regression(self):
    """Check that an XGBoost regressor reaches a low MAE on diabetes data."""
    np.random.seed(123)

    data = sklearn.datasets.load_diabetes()
    features, targets = data.data, data.target
    # 70/30 contiguous train/test split.
    split_at = int(.7 * len(features))
    train_dataset = dc.data.NumpyDataset(features[:split_at],
                                         targets[:split_at])
    test_dataset = dc.data.NumpyDataset(features[split_at:],
                                        targets[split_at:])

    regression_metric = dc.metrics.Metric(dc.metrics.mae_score)
    # early_stopping_rounds equals n_estimators, so early stopping
    # effectively never triggers during this test.
    booster = xgboost.XGBRegressor(n_estimators=50, seed=123)
    model = dc.models.XGBoostModel(
        booster, verbose=False, early_stopping_rounds=50)

    # Train and persist the model.
    model.fit(train_dataset)
    model.save()

    # Evaluate on the held-out split; MAE should be modest.
    scores = model.evaluate(test_dataset, [regression_metric])
    assert scores[regression_metric.name] < 50

  def test_xgboost_multitask_regression(self):
    """Test that xgboost models can learn on simple multitask regression."""
    np.random.seed(123)
    n_tasks = 4
    tasks = range(n_tasks)
    data = sklearn.datasets.load_diabetes()
    features = data.data
    # Replicate the single regression target across all tasks.
    targets = np.hstack(
        [np.reshape(data.target, (len(data.target), 1))] * n_tasks)

    # 70/30 contiguous train/test split.
    split_at = int(.7 * len(features))
    train_dataset = dc.data.DiskDataset.from_numpy(features[:split_at],
                                                   targets[:split_at])
    test_dataset = dc.data.DiskDataset.from_numpy(features[split_at:],
                                                  targets[split_at:])

    regression_metric = dc.metrics.Metric(dc.metrics.mae_score)

    def model_builder(model_dir):
      # One fresh XGBoost regressor per task.
      booster = xgboost.XGBRegressor(n_estimators=50, seed=123)
      return dc.models.XGBoostModel(
          booster, model_dir, verbose=False, early_stopping_rounds=50)

    model = dc.models.SingletaskToMultitask(tasks, model_builder)

    # Train and persist the model.
    model.fit(train_dataset)
    model.save()

    # Every per-task MAE should be modest.
    scores = model.evaluate(test_dataset, [regression_metric])
    for task_score in scores[regression_metric.name]:
      assert task_score < 50

  def test_xgboost_classification(self):
    """Test that XGBoost models can learn on simple classification datasets."""
    # NOTE: the original docstring said "sklearn models" — a copy-paste
    # error; this test exercises dc.models.XGBoostModel.
    np.random.seed(123)
    dataset = sklearn.datasets.load_digits(n_class=2)
    X, y = dataset.data, dataset.target

    # 70/30 contiguous train/test split.
    frac_train = .7
    n_samples = len(X)
    n_train = int(frac_train * n_samples)
    X_train, y_train = X[:n_train], y[:n_train]
    X_test, y_test = X[n_train:], y[n_train:]
    train_dataset = dc.data.NumpyDataset(X_train, y_train)
    test_dataset = dc.data.NumpyDataset(X_test, y_test)

    classification_metric = dc.metrics.Metric(dc.metrics.roc_auc_score)
    esr = {'early_stopping_rounds': 50}
    xgb_model = xgboost.XGBClassifier(n_estimators=50, seed=123)
    model = dc.models.XGBoostModel(xgb_model, verbose=False, **esr)

    # Fit trained model
    model.fit(train_dataset)
    model.save()

    # Eval model on test
    scores = model.evaluate(test_dataset, [classification_metric])
    assert scores[classification_metric.name] > .9
+120 −0
Original line number Diff line number Diff line
"""
Scikit-learn wrapper interface of xgboost
"""

import xgboost as xgb
import numpy as np
import os
from deepchem.models import Model
from deepchem.models.sklearn_models import SklearnModel
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV
import tempfile


class XGBoostModel(SklearnModel):
  """
  Wrapper for the scikit-learn interface of XGBoost models.

  Fitting grid-searches a small set of hyperparameters, uses early
  stopping on a 20% holdout to estimate the optimal number of boosting
  rounds, then retrains on the full dataset.
  """

  def __init__(self,
               model_instance=None,
               model_dir=None,
               verbose=False,
               **kwargs):
    """Abstract class for XGBoost models.
    Parameters:
    -----------
    model_instance: object
      Scikit-learn wrapper interface of xgboost; must be an
      xgboost.XGBClassifier or xgboost.XGBRegressor for fit() to work.
    model_dir: str
      Path to directory where model will be stored. A fresh temporary
      directory is created when omitted.
    verbose: bool
      Passed to xgboost's fit() to control per-round training output.
    early_stopping_rounds: int, optional (default 50)
      Read from **kwargs: rounds without eval-set improvement before
      training stops. All other kwargs are silently ignored.
    """
    if model_dir is not None:
      if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    else:
      model_dir = tempfile.mkdtemp()
    self.model_dir = model_dir
    self.model_instance = model_instance
    self.model_class = model_instance.__class__

    self.verbose = verbose
    # Only early_stopping_rounds is consumed from kwargs.
    if 'early_stopping_rounds' in kwargs:
      self.early_stopping_rounds = kwargs['early_stopping_rounds']
    else:
      self.early_stopping_rounds = 50

  def fit(self, dataset, **kwargs):
    """
    Fits XGBoost model to data.

    Note: sample weights (dataset.w) are not used in training, and
    extra **kwargs are ignored.

    Raises
    ------
    ValueError
      If model_instance is neither an XGBClassifier nor an XGBRegressor.
    """
    X = dataset.X
    y = np.squeeze(dataset.y)
    seed = self.model_instance.seed
    if isinstance(self.model_instance, xgb.XGBClassifier):
      xgb_metric = "auc"
      sklearn_metric = "roc_auc"
      stratify = y
    elif isinstance(self.model_instance, xgb.XGBRegressor):
      xgb_metric = "mae"
      sklearn_metric = "neg_mean_absolute_error"
      stratify = None
    else:
      # Bug fix: an unsupported model type previously fell through the
      # isinstance chain and crashed later with an obscure NameError on
      # xgb_metric. Fail loudly and early instead.
      raise ValueError(
          "model_instance must be an xgboost.XGBClassifier "
          "or xgboost.XGBRegressor")
    best_param = self._search_param(sklearn_metric, X, y)
    # Rebuild the wrapped model with the best parameters found.
    self.model_instance = self.model_class(**best_param)

    # Find optimal n_estimators based on original learning_rate
    # and early_stopping_rounds, using a 20% holdout as the eval set.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=stratify)

    self.model_instance.fit(
        X_train,
        y_train,
        early_stopping_rounds=self.early_stopping_rounds,
        eval_metric=xgb_metric,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        verbose=self.verbose)
    # Since test size is 20%, when retrain model to whole data, expect
    # n_estimator increased to 1/0.8 = 1.25 time.
    estimated_best_round = np.round(self.model_instance.best_ntree_limit * 1.25)
    self.model_instance.n_estimators = np.int64(estimated_best_round)
    self.model_instance.fit(X, y, eval_metric=xgb_metric, verbose=self.verbose)

  def _search_param(self, metric, X, y):
    '''
    Find best potential parameters set using few n_estimators.

    Runs a cheap 2-fold grid search with a large learning rate and a
    capped tree count for speed; the user's original learning_rate and
    n_estimators are restored in the returned parameter dict.
    '''
    # Make sure user specified params are in the grid.
    max_depth_grid = list(np.unique([self.model_instance.max_depth, 5, 7]))
    colsample_bytree_grid = list(
        np.unique([self.model_instance.colsample_bytree, 0.66, 0.9]))
    reg_lambda_grid = list(np.unique([self.model_instance.reg_lambda, 1, 5]))
    param_grid = {
        'max_depth': max_depth_grid,
        # Search with a faster (larger) learning rate and fewer trees;
        # both are reset to the user's values before returning.
        'learning_rate': [max(self.model_instance.learning_rate, 0.3)],
        'n_estimators': [min(self.model_instance.n_estimators, 60)],
        'gamma': [self.model_instance.gamma],
        'min_child_weight': [self.model_instance.min_child_weight],
        'max_delta_step': [self.model_instance.max_delta_step],
        'subsample': [self.model_instance.subsample],
        'colsample_bytree': colsample_bytree_grid,
        'colsample_bylevel': [self.model_instance.colsample_bylevel],
        'reg_alpha': [self.model_instance.reg_alpha],
        'reg_lambda': reg_lambda_grid,
        'scale_pos_weight': [self.model_instance.scale_pos_weight],
        'base_score': [self.model_instance.base_score],
        'seed': [self.model_instance.seed]
    }
    grid_search = GridSearchCV(
        self.model_instance, param_grid, cv=2, refit=False, scoring=metric)
    grid_search.fit(X, y)
    best_params = grid_search.best_params_
    # Change params back original params
    best_params['learning_rate'] = self.model_instance.learning_rate
    best_params['n_estimators'] = self.model_instance.n_estimators
    return best_params
+188 −23

File changed.

Preview size limit exceeded, changes collapsed.

Loading