Commit 385b857e authored by Peter Eastman's avatar Peter Eastman
Browse files

Merged changes from main branch

parents 56985e7f 6bbe6a48
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -496,26 +496,29 @@ Scaffold splitting
|                |Graphconv regression|Scaffold    |0.695         |0.391         |
|                |Weave regression    |Scaffold    |0.401         |0.373         |
|qm7             |NN regression       |Index       |0.997         |0.992         |
|                |DTNN                |Index       |0.998         |0.996         |
|                |DTNN                |Index       |0.997         |0.995         |
|                |NN regression       |Random      |0.998         |0.997         |
|                |DTNN                |Random      |0.998         |0.998         |
|                |DTNN                |Random      |0.999         |0.998         |
|                |NN regression       |Stratified  |0.998         |0.997         | 
|                |DTNN                |Stratified  |0.998         |0.998         | 
|qm7b            |MT-NN regression    |Index       |0.903         |0.789         |
|                |DTNN                |Index       |0.872         |0.821         |
|                |DTNN                |Index       |0.919         |0.863         |
|                |MT-NN regression    |Random      |0.893         |0.839         |
|                |DTNN                |Random      |0.865         |0.849         |
|                |DTNN                |Random      |0.924         |0.898         |
|                |MT-NN regression    |Stratified  |0.891         |0.859         | 
|                |DTNN                |Stratified  |0.853         |0.839         | 
|                |DTNN                |Stratified  |0.913         |0.894         | 
|qm8             |MT-NN regression    |Index       |0.783         |0.656         |
|                |DTNN                |Index       |0.737         |0.639         |
|                |DTNN                |Index       |0.857         |0.691         |
|                |MT-NN regression    |Random      |0.747         |0.660         |
|                |DTNN                |Random      |0.731         |0.711         |
|                |DTNN                |Random      |0.842         |0.756         |
|                |MT-NN regression    |Stratified  |0.756         |0.681         |
|                |DTNN                |Stratified  |0.714         |0.683         | 
|                |DTNN                |Stratified  |0.844         |0.758         | 
|qm9             |MT-NN regression    |Index       |0.733         |0.766         |
|                |DTNN                |Index       |0.918         |0.831         | 
|                |MT-NN regression    |Random      |0.852         |0.833         |
|                |DTNN                |Random      |0.942         |0.948         | 
|                |MT-NN regression    |Stratified  |0.764         |0.792         | 
|                |DTNN                |Stratified  |0.941         |0.867         | 
|sampl           |Random forest       |Index       |0.968         |0.736         |
|                |XGBoost             |Index       |0.884         |0.784         |
|                |NN regression       |Index       |0.917         |0.764         |
@@ -675,9 +678,13 @@ Time needed for benchmark test(~20h in total)
|                |Graphconv regression|20              |100            |
|                |Weave regression    |20              |120            |
|qm7             |MT-NN regression    |10              |400            |
|                |DTNN                |10              |600            |
|qm7b            |MT-NN regression    |10              |600            |
|                |DTNN                |10              |600            |
|qm8             |MT-NN regression    |60              |1000           |
|                |DTNN                |10              |2000           |
|qm9             |MT-NN regression    |220             |10000          |
|                |DTNN                |10              |14000          |
|sampl           |NN regression       |10              |30             |
|                |XGBoost             |10              |20             |
|                |Random forest       |10              |20             |
+5 −2
Original line number Diff line number Diff line
@@ -1049,8 +1049,11 @@ class Databag(object):
  A utility class to iterate through multiple datasets together.
  """

  def __init__(self):
  def __init__(self, datasets=None):
    if datasets is None:
      self.datasets = dict()
    else:
      self.datasets = datasets

  def add_dataset(self, key, dataset):
    self.datasets[key] = dataset
+12 −2
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from scipy.stats import pearsonr


@@ -70,6 +72,14 @@ def pearson_r2_score(y, y_pred):
  return pearsonr(y, y_pred)[0]**2


def prc_auc_score(y, y_pred):
  """Compute area under precision-recall curve"""
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2
  precision, recall, _ = precision_recall_curve(y[:, 1], y_pred[:, 1])
  return auc(recall, precision)


def rms_score(y_true, y_pred):
  """Computes RMS error."""
  return np.sqrt(mean_squared_error(y_true, y_pred))
@@ -148,7 +158,7 @@ class Metric(object):
      if self.metric.__name__ in [
          "roc_auc_score", "matthews_corrcoef", "recall_score",
          "accuracy_score", "kappa_score", "precision_score",
          "balanced_accuracy_score"
          "balanced_accuracy_score", "prc_auc_score"
      ]:
        mode = "classification"
      elif self.metric.__name__ in [
@@ -267,7 +277,7 @@ class Metric(object):
      # TODO(rbharath): This has been a major source of bugs. Is there a more
      # robust characterization of which metrics require class-probs and which
      # don't?
      if "roc_auc_score" in self.name:
      if "roc_auc_score" in self.name or "prc_auc_score" in self.name:
        y_true = to_one_hot(y_true).astype(int)
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
      else:
+2 −2
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@ from deepchem.models.sklearn_models import SklearnModel
from deepchem.models.xgboost_models import XGBoostModel
from deepchem.models.tf_new_models.multitask_classifier import MultitaskGraphClassifier
from deepchem.models.tf_new_models.multitask_regressor import MultitaskGraphRegressor
from deepchem.models.tf_new_models.DTNN_regressor import DTNNGraphRegressor

from deepchem.models.tf_new_models.support_classifier import SupportGraphClassifier
from deepchem.models.multitask import SingletaskToMultitask
@@ -26,3 +25,4 @@ from deepchem.models.tensorflow_models.progressive_multitask import ProgressiveM
from deepchem.models.tensorflow_models.progressive_joint import ProgressiveJointRegressor
from deepchem.models.tensorflow_models.IRV import TensorflowMultiTaskIRVClassifier
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph
+714 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading