Commit 4b4a572d authored by Bharath Ramsundar
Browse files

Added support for weighting positives.

parent 6e65d9cf
Loading
Loading
Loading
Loading
+21 −8
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from deep_chem.utils.load import load_datasets
from deep_chem.utils.load import ensure_balanced
from deep_chem.utils.preprocess import multitask_to_singletask
from deep_chem.utils.preprocess import train_test_random_split
from deep_chem.utils.preprocess import train_test_scaffold_split
@@ -21,7 +22,7 @@ from deep_chem.utils.evaluate import compute_roc_auc_scores
from deep_chem.utils.load import load_and_transform_dataset

def process_multitask(paths, task_transforms, desc_transforms, splittype="random",
    seed=None, add_descriptors=False, desc_weight=0.5):
    seed=None, add_descriptors=False, weight_positives=False, desc_weight=0.5):
  """Extracts multitask datasets and splits into train/test.

  Returns a tuple of test/train datasets, fingerprints, and labels.
@@ -45,7 +46,7 @@ def process_multitask(paths, task_transforms, desc_transforms, splittype="random
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms, desc_transforms,
      add_descriptors=add_descriptors)
      add_descriptors=add_descriptors, weight_positives=weight_positives)
  if splittype == "random":
    train, test = train_test_random_split(dataset, seed=seed)
  elif splittype == "scaffold":
@@ -54,12 +55,18 @@ def process_multitask(paths, task_transforms, desc_transforms, splittype="random
    raise ValueError("Improper splittype. Must be random/scaffold.")
  X_train, y_train, W_train = dataset_to_numpy(train,
      add_descriptors=add_descriptors, desc_weight=desc_weight)
  if weight_positives:
    print "Train set balance"
    ensure_balanced(y_train, W_train)
  X_test, y_test, W_test = dataset_to_numpy(test,
      add_descriptors=add_descriptors, desc_weight=desc_weight)
  if weight_positives:
    print "Test set balance"
    ensure_balanced(y_test, W_test)
  return (train, X_train, y_train, W_train, test, X_test, y_test, W_test)

def process_singletask(paths, task_transforms, desc_transforms, splittype="random", seed=None,
    add_descriptors=False, desc_weight=0.5):
    add_descriptors=False, desc_weight=0.5, weight_positives=True):
  """Extracts singletask datasets and splits into train/test.

  Returns a dict that maps target names to tuples.
@@ -77,11 +84,14 @@ def process_singletask(paths, task_transforms, desc_transforms, splittype="rando
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms, desc_transforms,
      add_descriptors=add_descriptors)
      add_descriptors=add_descriptors, weight_positives=weight_positives)
  singletask = multitask_to_singletask(dataset)
  arrays = {}
  for target in singletask:
    print target
    data = singletask[target]
    print "len(data)"
    print len(data)
    # TODO(rbharath): Remove limitation after debugging.
    if len(data) == 0:
      continue
@@ -102,7 +112,7 @@ def process_singletask(paths, task_transforms, desc_transforms, splittype="rando

def fit_multitask_mlp(paths, task_types, task_transforms, desc_transforms,
                      splittype="random", add_descriptors=False, desc_weight=0.5,
                      **training_params):
                      weight_positives=False, **training_params):
  """
  Perform stochastic gradient descent optimization for a keras multitask MLP.
  Returns AUCs, R^2 scores, and RMS values.
@@ -127,7 +137,8 @@ def fit_multitask_mlp(paths, task_types, task_transforms, desc_transforms,
  """
  (train, X_train, y_train, W_train, test, X_test, y_test, W_test) = (
      process_multitask(paths, task_transforms, desc_transforms,
      splittype=splittype, add_descriptors=add_descriptors, desc_weight=desc_weight))
      splittype=splittype, add_descriptors=add_descriptors, desc_weight=desc_weight,
      weight_positives=weight_positives))
  print np.shape(y_train)
  model = train_multitask_model(X_train, y_train, W_train, task_types,
                                desc_transforms, add_descriptors=add_descriptors,
@@ -150,7 +161,7 @@ def fit_multitask_mlp(paths, task_types, task_transforms, desc_transforms,
def fit_singletask_mlp(paths, task_types, task_transforms,
                       desc_transforms, splittype="random",
                       add_descriptors=False, desc_weight=0.5,
                       **training_params):
                       weight_positives=True, num_to_train=None, **training_params):
  """
  Perform stochastic gradient descent optimization for a keras MLP.

@@ -170,10 +181,12 @@ def fit_singletask_mlp(paths, task_types, task_transforms,
  """
  singletasks = process_singletask(paths, task_transforms, desc_transforms,
    splittype=splittype, add_descriptors=add_descriptors,
    desc_weight=desc_weight)
    desc_weight=desc_weight, weight_positives=weight_positives)
  ret_vals = {}
  aucs, r2s, rms = {}, {}, {}
  sorted_targets = sorted(singletasks.keys())
  if num_to_train:
    sorted_targets = sorted_targets[:num_to_train]
  for index, target in enumerate(sorted_targets):
    print "Training model %d" % index
    (train, X_train, y_train, W_train, test, X_test, y_test, W_test) = (
+5 −2
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@ from sklearn.svm import SVR

def fit_singletask_models(paths, modeltype, task_types, task_transforms,
    add_descriptors=False, desc_transforms={}, splittype="random",
    seed=None):
    seed=None, num_to_train=None):
  """Fits singletask linear regression models to potency.

  Parameters
@@ -52,7 +52,10 @@ def fit_singletask_models(paths, modeltype, task_types, task_transforms,
      add_descriptors=add_descriptors)
  singletask = multitask_to_singletask(dataset)
  aucs, r2s, rms = {}, {}, {}
  for index, target in enumerate(sorted(singletask.keys())):
  sorted_targets = sorted(singletask.keys())
  if num_to_train:
    sorted_targets = sorted_targets[:num_to_train]
  for index, target in enumerate(sorted_targets):
    print "Building model %d" % index
    data = singletask[target]
    if splittype == "random":
+10 −19
Original line number Diff line number Diff line
@@ -5,22 +5,7 @@ import argparse
import numpy as np
from deep_chem.models.deep import fit_singletask_mlp
from deep_chem.models.deep import fit_multitask_mlp
from deep_chem.models.deep import train_multitask_model
from deep_chem.models.standard import fit_singletask_models
from deep_chem.models.standard import fit_multitask_rf
from deep_chem.utils.analysis import compare_datasets
from deep_chem.utils.evaluate import eval_model
from deep_chem.utils.evaluate import compute_roc_auc_scores
from deep_chem.utils.evaluate import compute_r2_scores
from deep_chem.utils.evaluate import compute_rms_scores
from deep_chem.utils.load import get_target_names
from deep_chem.utils.load import load_datasets
from deep_chem.utils.load import load_and_transform_dataset
from deep_chem.utils.preprocess import dataset_to_numpy
from deep_chem.utils.preprocess import train_test_random_split
from deep_chem.utils.preprocess import train_test_scaffold_split
from deep_chem.utils.preprocess import scaffold_separate
from deep_chem.utils.preprocess import multitask_to_singletask
from deep_chem.utils.load import get_default_task_types_and_transforms
from deep_chem.utils.preprocess import get_default_descriptor_transforms

@@ -53,6 +38,11 @@ def parse_args(input_args=None):
                  help="Learning rate decay for NN models.")
  parser.add_argument("--validation-split", type=float, default=0.0,
                  help="Percent of training data to use for validation.")
  parser.add_argument("--weight-positives", type=bool, default=False,
                  help="Weight positive examples to have same total weight as negatives.")
  # TODO(rbharath): Remove this once debugging is complete.
  parser.add_argument("--num-to-train", type=int, default=None,
                  help="Number of datasets to train on. Only for debug.")
  return parser.parse_args(input_args)

def main():
@@ -70,16 +60,17 @@ def main():
      n_hidden=args.n_hidden, learning_rate=args.learning_rate,
      dropout=args.dropout, nb_epoch=args.n_epochs, decay=args.decay,
      batch_size=args.batch_size,
      validation_split=args.validation_split)
      validation_split=args.validation_split,
      weight_positives=args.weight_positives, num_to_train=args.num_to_train)
  elif args.model == "multitask_deep_network":
    fit_multitask_mlp(paths.values(), task_types, task_transforms,
      desc_transforms, splittype=args.splittype, add_descriptors=False,
      n_hidden=args.n_hidden, learning_rate = args.learning_rate, dropout = args.dropout,
      batch_size=args.batch_size,
      nb_epoch=args.n_epochs, decay=args.decay, validation_split=args.validation_split)
      batch_size=args.batch_size, nb_epoch=args.n_epochs, decay=args.decay,
      validation_split=args.validation_split, weight_positives=args.weight_positives)
  else:
    fit_singletask_models(paths.values(), args.model, task_types,
        task_transforms, splittype=args.splittype)
        task_transforms, splittype=args.splittype, num_to_train=args.num_to_train)

if __name__ == "__main__":
  main()
+19 −2
Original line number Diff line number Diff line
@@ -232,9 +232,22 @@ def load_vs_datasets(paths, target_dir_name="targets",
                      "labels": labels[smiles]}
  return data

def ensure_balanced(y, W):
  """Helper function that ensures positives and negatives are balanced.

  For every target (column of y), the weights in W of the samples labeled 1
  and of the samples labeled 0 are summed separately, and the two totals must
  agree to within np.isclose's default tolerance. Samples carrying any other
  label (e.g. -1) are ignored.

  Parameters
  ----------
  y: np.ndarray
    Two-dimensional label matrix; rows are samples, columns are targets.
  W: np.ndarray
    Weight matrix indexed the same way as y.

  Raises
  ------
  AssertionError
    If the positive and negative weight totals differ for some target.
  """
  _, n_targets = np.shape(y)
  for target_ind in range(n_targets):
    labels = y[:, target_ind]
    weights = W[:, target_ind]
    pos_weight = weights[labels == 1].sum()
    neg_weight = weights[labels == 0].sum()
    # Raise explicitly instead of using a bare assert so that the check is not
    # stripped under "python -O" and the failure says which target is off.
    if not np.isclose(pos_weight, neg_weight):
      raise AssertionError(
          "Unbalanced weights for target %d: positives %f, negatives %f"
          % (target_ind, pos_weight, neg_weight))
  # Single-argument call form prints identically on Python 2 and Python 3.
  print("WEIGHTS ARE BALANCED")

def load_and_transform_dataset(paths, task_transforms, desc_transforms={},
    labels_endpoint="labels", descriptors_endpoint="descriptors",
    add_descriptors=False):
    add_descriptors=False, weight_positives=True):
  """Transform data labels as specified

  Parameters
@@ -254,7 +267,11 @@ def load_and_transform_dataset(paths, task_transforms, desc_transforms={},
  """
  dataset = load_datasets(paths, add_descriptors=add_descriptors)
  X, y, W = transform_outputs(dataset, task_transforms,
      desc_transforms=desc_transforms, add_descriptors=add_descriptors)
      desc_transforms=desc_transforms, add_descriptors=add_descriptors,
      weight_positives=weight_positives)
  # TODO(rbharath): Take this out once test passes
  if weight_positives:
    ensure_balanced(y, W)
  trans_data = {}
  sorted_smiles = sorted(dataset.keys())
  sorted_targets = sorted(task_transforms.keys())
+33 −8
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ def get_default_descriptor_transforms():
  return desc_transforms

def transform_outputs(dataset, task_transforms, desc_transforms={},
    add_descriptors=False):
    add_descriptors=False, weight_positives=True):
  """Tranform the provided outputs

  Parameters
@@ -36,7 +36,8 @@ def transform_outputs(dataset, task_transforms, desc_transforms={},
  add_descriptors: bool
    Add descriptor prediction as extra task.
  """
  X, y, W = dataset_to_numpy(dataset, add_descriptors=add_descriptors)
  X, y, W = dataset_to_numpy(dataset, add_descriptors=add_descriptors,
      weight_positives=weight_positives)
  sorted_targets = sorted(task_transforms.keys())
  if add_descriptors:
    sorted_descriptors = sorted(desc_transforms.keys())
@@ -91,9 +92,30 @@ def to_one_hot(y):
      y_hot[index] = np.array([0, 1])
  return y_hot

def balance_positives(y, W):
  """Ensure that positive and negative examples have equal weight.

  For each target (column of y), every positive example (label 1) is given
  the weight n_negatives / n_positives and every negative example (label 0)
  is given the weight 1, so that both classes carry the same total weight.
  Entries whose label is -1 (missing) keep whatever weight they already have.

  A target without any positive example cannot be balanced; its positives are
  (trivially) left alone and its negatives still get weight 1, instead of the
  function failing with a ZeroDivisionError.

  Parameters
  ----------
  y: np.ndarray
    Two-dimensional label matrix; rows are samples, columns are targets.
    Every entry must be 0, 1 or -1.
  W: np.ndarray
    Weight matrix indexed the same way as y. Modified in place.

  Returns
  -------
  The same W array that was passed in, after the update.

  Raises
  ------
  ValueError
    If y contains a label other than 0, 1 or -1.
  """
  _, n_targets = np.shape(y)
  for target_ind in range(n_targets):
    labels = y[:, target_ind]
    is_positive = labels == 1
    is_negative = labels == 0
    is_missing = labels == -1
    if not np.all(is_positive | is_negative | is_missing):
      raise ValueError("Labels must be 0/1 or -1 (missing data) for balance_positives.")
    n_positives = np.count_nonzero(is_positive)
    n_negatives = np.count_nonzero(is_negative)
    # Guard the division: a target may have no positives at all (for example
    # after a train/test split, or when every label is missing).
    if n_positives:
      W[is_positive, target_ind] = float(n_negatives)/float(n_positives)
    W[is_negative, target_ind] = 1
  return W

def dataset_to_numpy(dataset, feature_endpoint="fingerprint",
    labels_endpoint="labels", descriptors_endpoint="descriptors",
    desc_weight=.5, add_descriptors=False):
    desc_weight=.5, add_descriptors=False, weight_positives=True):
  """Transforms a loaded dataset into numpy arrays (X, y).

  Transforms provided dict into feature matrix X (of dimensions [n_samples,
@@ -106,6 +128,8 @@ def dataset_to_numpy(dataset, feature_endpoint="fingerprint",
  (this is relatively safe since the ratio of positive to negative examples
  is on the order 1/100)

  TODO(rbharath): Clean this up and remove some of the extra arguments.
  
  Parameters
  ----------
  dataset: dict 
@@ -137,7 +161,7 @@ def dataset_to_numpy(dataset, feature_endpoint="fingerprint",
    # Set labels from measurements
    for t_ind, target in enumerate(sorted_targets):
      if labels[target] == -1:
        y[index][t_ind] = 0
        y[index][t_ind] = -1
        W[index][t_ind] = 0
      else:
        y[index][t_ind] = labels[target]
@@ -145,6 +169,8 @@ def dataset_to_numpy(dataset, feature_endpoint="fingerprint",
      # Set labels from descriptors
      y[index][n_targets:] = descriptors
      W[index][n_targets:] = desc_weight
  if weight_positives:
    W = balance_positives(y, W)
  return X, y, W

def multitask_to_singletask(dataset):
@@ -162,15 +188,14 @@ def multitask_to_singletask(dataset):
  # Generate single-task data structures
  labels = dataset.itervalues().next()["labels"]
  sorted_targets = sorted(labels.keys())
  # TODO(rbharath): Replace this with a dictionary comprehension
  singletask = {}
  for target in sorted_targets:
    singletask[target] = {} 
  singletask = {target: {} for target in sorted_targets}
  # Populate the singletask datastructures
  sorted_smiles = sorted(dataset.keys())
  for index, smiles in enumerate(sorted_smiles):
    datapoint = dataset[smiles]
    labels = datapoint["labels"]
    if index < 10:
      print labels
    for t_ind, target in enumerate(sorted_targets):
      if labels[target] == -1:
        continue
Loading