Commit 3971bd3c authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Batch of changes to fix vanilla multitask networks.

parent 2c0ea7f2
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ def process_multitask(paths, task_transforms, splittype="random",
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms,
			prediction_endpoint,
      weight_positives=weight_positives)
  sorted_targets = sorted(dataset.keys())
  if splittype == "random":
@@ -63,7 +64,9 @@ def process_multitask(paths, task_transforms, splittype="random",
  #  ensure_balanced(y_test, W_test)
  return (train, X_train, y_train, W_train, test, X_test, y_test, W_test)

def process_singletask(paths, task_transforms, splittype="random", seed=None,
def process_singletask(paths, task_transforms,
		prediction_endpoint,
		splittype="random", seed=None,
    weight_positives=True):
  """Extracts singletask datasets and splits into train/test.

@@ -82,6 +85,7 @@ def process_singletask(paths, task_transforms, splittype="random", seed=None,
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms,
			prediction_endpoint,
      weight_positives=weight_positives)
  singletask = multitask_to_singletask(dataset)
  arrays = {}
@@ -101,7 +105,7 @@ def process_singletask(paths, task_transforms, splittype="random", seed=None,
  return arrays


def fit_multitask_mlp(paths, task_types, task_transforms,
def fit_multitask_mlp(paths, task_types, task_transforms, prediction_endpoint,
                      splittype="random", weight_positives=False, **training_params):
  """
  Perform stochastic gradient descent optimization for a keras multitask MLP.
@@ -137,6 +141,7 @@ def fit_multitask_mlp(paths, task_types, task_transforms,
    print "Mean R^2: %f" % np.mean(np.array(r2s.values()))

def fit_singletask_mlp(paths, task_types, task_transforms,
											 prediction_endpoint,
                       splittype="random", weight_positives=True,
                       num_to_train=None, **training_params):
  """
@@ -154,6 +159,7 @@ def fit_singletask_mlp(paths, task_types, task_transforms,
    Aggregates keyword parameters to pass to train_multitask_model
  """
  singletasks = process_singletask(paths, task_transforms,
		prediction_endpoint,
    splittype=splittype, weight_positives=weight_positives)
  ret_vals = {}
  aucs, r2s, rms = {}, {}, {}
@@ -163,7 +169,7 @@ def fit_singletask_mlp(paths, task_types, task_transforms,
  for index, target in enumerate(sorted_targets):
    print "Training model %d" % index
    print "Target %s" % target
    (train, X_train, y_train, W_train, test, X_test, y_test, W_test) = (
    (train, X_train, y_train, W_train), (test, X_test, y_test, W_test) = (
        singletasks[target])
    model = train_multitask_model(X_train, y_train, W_train,
        {target: task_types[target]}, **training_params)
+6 −4
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from deep_chem.utils.evaluate import compute_r2_scores
# code in deep.py
# TODO(rbharath): paths is to handle sharded input pickle files. Might be
# better to use hdf5 datasets like in MSMBuilder
def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random"):
def process_3D_convolutions(paths, task_transforms, prediction_endpoint, seed=None, splittype="random"):
  """Loads 3D Convolution datasets.

  Parameters
@@ -25,7 +25,8 @@ def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random
  paths: list
    List of paths to convolution datasets.
  """
  dataset = load_and_transform_dataset(paths, task_transforms, datatype="pdbbind")
  dataset = load_and_transform_dataset(paths, task_transforms,
    prediction_endpoint, datatype="pdbbind")
  # TODO(rbharath): Factor this code splitting out into a util function.
  if splittype == "random":
    train, test = train_test_random_split(dataset, seed=seed)
@@ -35,12 +36,13 @@ def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random
  X_test, y_test, W_test = tensor_dataset_to_numpy(test)
  return (X_train, y_train, W_train, train), (X_test, y_test, W_test, test)

def fit_3D_convolution(paths, task_types, task_transforms, axis_length=32, **training_params):
def fit_3D_convolution(paths, task_types, task_transforms, prediction_endpoint,
    axis_length=32, **training_params):
  """
  Perform stochastic gradient descent for a 3D CNN.
  """
  (X_train, y_train, W_train, train), (X_test, y_test, W_test, test) = process_3D_convolutions(
    paths, task_transforms)
    paths, task_transforms, prediction_endpoint)

  print "np.shape(X_train): " + str(np.shape(X_train))
  print "np.shape(y_train): " + str(np.shape(y_train))
+9 −3
Original line number Diff line number Diff line
@@ -25,6 +25,8 @@ def parse_args(input_args=None):
  parser.add_argument("--splittype", type=str, default="scaffold",
                       choices=["scaffold", "random"],
                       help="Type of cross-validation data-splitting.")
  parser.add_argument("--prediction-endpoint", type=str, default="IC50",
                       help="Name of measured endpoint to predict.")
  parser.add_argument("--n-hidden", type=int, default=500,
                      help="Number of hidden neurons for NN models.")
  parser.add_argument("--learning-rate", type=float, default=0.01,
@@ -53,7 +55,6 @@ def main():
  args = parse_args()
  paths = {}


  for dataset, path in zip(args.datasets, args.paths):
    paths[dataset] = path

@@ -61,20 +62,25 @@ def main():

  if args.model == "singletask_deep_network":
    fit_singletask_mlp(paths.values(), task_types, task_transforms,
      splittype=args.splittype, n_hidden=args.n_hidden,
      prediction_endpoint=args.prediction_endpoint,
      splittype=args.splittype, 
      n_hidden=args.n_hidden,
      learning_rate=args.learning_rate, dropout=args.dropout,
      nb_epoch=args.n_epochs, decay=args.decay, batch_size=args.batch_size,
      validation_split=args.validation_split,
      weight_positives=args.weight_positives, num_to_train=args.num_to_train)
  elif args.model == "multitask_deep_network":
    fit_multitask_mlp(paths.values(), task_types, task_transforms,
      splittype=args.splittype, n_hidden=args.n_hidden, learning_rate =
      prediction_endpoint=args.prediction_endpoint,
      splittype=args.splittype,
      n_hidden=args.n_hidden, learning_rate =
      args.learning_rate, dropout = args.dropout, batch_size=args.batch_size,
      nb_epoch=args.n_epochs, decay=args.decay,
      validation_split=args.validation_split,
      weight_positives=args.weight_positives)
  elif args.model == "3D_cnn":
    fit_3D_convolution(paths.values(), task_types, task_transforms,
        prediction_endpoint=args.prediction_endpoint,
        axis_length=args.axis_length, nb_epoch=args.n_epochs,
        batch_size=args.batch_size)
  else:
+1 −1
Original line number Diff line number Diff line
# Usage ./process_bace.sh INPUT_SDF_FILE
python -m deep_chem.scripts.process_dataset --input_file $1 --input-type sdf --fields Name smiles pIC50 --field-types string string concentration --name BACE --out /tmp/
python -m deep_chem.scripts.process_dataset --input-file $1 --input-type sdf --fields Name smiles pIC50 --field-types string string concentration --name BACE --out /tmp/
+21 −3
Original line number Diff line number Diff line
@@ -105,8 +105,20 @@ def get_rows(input_file, input_type):
      df = pickle.load(f)
    return df.iterrows()
  elif input_type == "sdf":
    if ".gz" in input_file:
      print "gzipped"
      with gzip.open(input_file) as f:
      return Chem.ForwardSDMolSupplier(f)
        supp = Chem.ForwardSDMolSupplier(f)
        mols = [mol for mol in supp if mol is not None]
      print "len(mols): " + str(len(mols))
      return mols
    else:
      print "non-gzipped"
      with open(input_file) as f:
        supp  = Chem.ForwardSDMolSupplier(f)
        mols = [mol for mol in supp if mol is not None]
      print "len(mols): " + str(len(mols))
      return mols

def get_row_data(row, input_type, fields, field_types):
  """Extract information from row data."""
@@ -133,6 +145,7 @@ def get_row_data(row, input_type, fields, field_types):
        row_data[ind] = None
      else:
        row_data[ind] = mol.GetProp(field)
    return row_data

def process_field(data, field_type):
  """Parse data in a field."""
@@ -141,7 +154,11 @@ def process_field(data, field_type):
  elif field_type == "float":
    return parse_float_input(data)
  elif field_type == "concentration":
    return parse_float_input(data) / 1e-6
    fl = parse_float_input(data)
    if fl is not None:
      return parse_float_input(data) / 1e-7
    else:
      return None
  elif field_type == "list-string":
    return data.split(",")
  elif field_type == "list-float":
@@ -155,6 +172,7 @@ def generate_targets(input_file, input_type, fields, field_types, out_pkl,
  rows, mols, smiles = [], [], SmilesGenerator()
  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
    print row_index
    print raw_row
    # Skip row labels.
    if row_index == 0 or raw_row is None:
      continue
Loading