Commit 9bed9e66 authored by Bharath Ramsundar
Browse files

Merge pull request #6 from rbharath/multitask

Fix Broken fully-connected Networks
parents 1d890efc 24eaabee
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ def process_multitask(paths, task_transforms, splittype="random",
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms,
      prediction_endpoint,
      weight_positives=weight_positives)
  sorted_targets = sorted(dataset.keys())
  if splittype == "random":
@@ -63,7 +64,9 @@ def process_multitask(paths, task_transforms, splittype="random",
  #  ensure_balanced(y_test, W_test)
  return (train, X_train, y_train, W_train, test, X_test, y_test, W_test)

def process_singletask(paths, task_transforms, splittype="random", seed=None,
def process_singletask(paths, task_transforms,
    prediction_endpoint,
    splittype="random", seed=None,
    weight_positives=True):
  """Extracts singletask datasets and splits into train/test.

@@ -82,6 +85,7 @@ def process_singletask(paths, task_transforms, splittype="random", seed=None,
    Seed used for random splits.
  """
  dataset = load_and_transform_dataset(paths, task_transforms,
      prediction_endpoint,
      weight_positives=weight_positives)
  singletask = multitask_to_singletask(dataset)
  arrays = {}
@@ -101,7 +105,7 @@ def process_singletask(paths, task_transforms, splittype="random", seed=None,
  return arrays


def fit_multitask_mlp(paths, task_types, task_transforms,
def fit_multitask_mlp(paths, task_types, task_transforms, prediction_endpoint,
                      splittype="random", weight_positives=False, **training_params):
  """
  Perform stochastic gradient descent optimization for a keras multitask MLP.
@@ -137,6 +141,7 @@ def fit_multitask_mlp(paths, task_types, task_transforms,
    print "Mean R^2: %f" % np.mean(np.array(r2s.values()))

def fit_singletask_mlp(paths, task_types, task_transforms,
                       prediction_endpoint,
                       splittype="random", weight_positives=True,
                       num_to_train=None, **training_params):
  """
@@ -154,6 +159,7 @@ def fit_singletask_mlp(paths, task_types, task_transforms,
    Aggregates keyword parameters to pass to train_multitask_model
  """
  singletasks = process_singletask(paths, task_transforms,
    prediction_endpoint,
    splittype=splittype, weight_positives=weight_positives)
  ret_vals = {}
  aucs, r2s, rms = {}, {}, {}
@@ -163,7 +169,7 @@ def fit_singletask_mlp(paths, task_types, task_transforms,
  for index, target in enumerate(sorted_targets):
    print "Training model %d" % index
    print "Target %s" % target
    (train, X_train, y_train, W_train, test, X_test, y_test, W_test) = (
    (train, X_train, y_train, W_train), (test, X_test, y_test, W_test) = (
        singletasks[target])
    model = train_multitask_model(X_train, y_train, W_train,
        {target: task_types[target]}, **training_params)
+6 −4
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from deep_chem.utils.evaluate import compute_r2_scores
# code in deep.py
# TODO(rbharath): paths is to handle sharded input pickle files. Might be
# better to use hdf5 datasets like in MSMBuilder
def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random"):
def process_3D_convolutions(paths, task_transforms, prediction_endpoint, seed=None, splittype="random"):
  """Loads 3D Convolution datasets.

  Parameters
@@ -25,7 +25,8 @@ def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random
  paths: list
    List of paths to convolution datasets.
  """
  dataset = load_and_transform_dataset(paths, task_transforms, datatype="pdbbind")
  dataset = load_and_transform_dataset(paths, task_transforms,
    prediction_endpoint, datatype="pdbbind")
  # TODO(rbharath): Factor this code splitting out into a util function.
  if splittype == "random":
    train, test = train_test_random_split(dataset, seed=seed)
@@ -35,12 +36,13 @@ def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random
  X_test, y_test, W_test = tensor_dataset_to_numpy(test)
  return (X_train, y_train, W_train, train), (X_test, y_test, W_test, test)

def fit_3D_convolution(paths, task_types, task_transforms, axis_length=32, **training_params):
def fit_3D_convolution(paths, task_types, task_transforms, prediction_endpoint,
    axis_length=32, **training_params):
  """
  Perform stochastic gradient descent for a 3D CNN.
  """
  (X_train, y_train, W_train, train), (X_test, y_test, W_test, test) = process_3D_convolutions(
    paths, task_transforms)
    paths, task_transforms, prediction_endpoint)

  print "np.shape(X_train): " + str(np.shape(X_train))
  print "np.shape(y_train): " + str(np.shape(y_train))
+9 −3
Original line number Diff line number Diff line
@@ -25,6 +25,8 @@ def parse_args(input_args=None):
  parser.add_argument("--splittype", type=str, default="scaffold",
                       choices=["scaffold", "random"],
                       help="Type of cross-validation data-splitting.")
  parser.add_argument("--prediction-endpoint", type=str, default="IC50",
                       help="Name of measured endpoint to predict.")
  parser.add_argument("--n-hidden", type=int, default=500,
                      help="Number of hidden neurons for NN models.")
  parser.add_argument("--learning-rate", type=float, default=0.01,
@@ -53,7 +55,6 @@ def main():
  args = parse_args()
  paths = {}


  for dataset, path in zip(args.datasets, args.paths):
    paths[dataset] = path

@@ -61,20 +62,25 @@ def main():

  if args.model == "singletask_deep_network":
    fit_singletask_mlp(paths.values(), task_types, task_transforms,
      splittype=args.splittype, n_hidden=args.n_hidden,
      prediction_endpoint=args.prediction_endpoint,
      splittype=args.splittype, 
      n_hidden=args.n_hidden,
      learning_rate=args.learning_rate, dropout=args.dropout,
      nb_epoch=args.n_epochs, decay=args.decay, batch_size=args.batch_size,
      validation_split=args.validation_split,
      weight_positives=args.weight_positives, num_to_train=args.num_to_train)
  elif args.model == "multitask_deep_network":
    fit_multitask_mlp(paths.values(), task_types, task_transforms,
      splittype=args.splittype, n_hidden=args.n_hidden, learning_rate =
      prediction_endpoint=args.prediction_endpoint,
      splittype=args.splittype,
      n_hidden=args.n_hidden, learning_rate =
      args.learning_rate, dropout = args.dropout, batch_size=args.batch_size,
      nb_epoch=args.n_epochs, decay=args.decay,
      validation_split=args.validation_split,
      weight_positives=args.weight_positives)
  elif args.model == "3D_cnn":
    fit_3D_convolution(paths.values(), task_types, task_transforms,
        prediction_endpoint=args.prediction_endpoint,
        axis_length=args.axis_length, nb_epoch=args.n_epochs,
        batch_size=args.batch_size)
  else:
+2 −0
Original line number Diff line number Diff line
# Usage: ./process_bace.sh INPUT_SDF_FILE
# Runs deep_chem.scripts.process_dataset on the given SDF file, reading the
# Name, smiles and pIC50 fields, naming the dataset BACE and passing /tmp/ as
# the output location.
# "$1" is quoted so a path containing spaces or glob characters is passed
# through as a single argument.
python -m deep_chem.scripts.process_dataset --input-file "$1" --input-type sdf --fields Name smiles pIC50 --field-types string string concentration --name BACE --out /tmp/
+65 −32
Original line number Diff line number Diff line
@@ -16,17 +16,19 @@ from vs_utils.utils import SmilesGenerator
def parse_args(input_args=None):
  """Parse command-line arguments."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--input-file', required=1,
                      help='Input file with data.')
  parser.add_argument("--input-file", required=1,
                      help="Input file with data.")
  parser.add_argument("--input-type", default="csv",
                      choices=["csv", "pandas"],
                      help="Type of input file. The pkl.gz must contain a pandas dataframe.")
  parser.add_argument("--columns", required=1, nargs="+",
                      help = "Names of columns.")
  parser.add_argument('--column-types', required=1, nargs="+",
                      choices=['string', 'float', 'list-string', 'list-float',
                               'ndarray'],
                      help='Type of data in columns.')
                      choices=["csv", "pandas", "sdf"],
                      help="Type of input file. If pandas, input must be a pkl.gz\n"
                           "containing a pandas dataframe. If sdf, should be in\n"
                           "(perhaps gzipped) sdf file.")
  parser.add_argument("--fields", required=1, nargs="+",
                      help = "Names of fields.")
  parser.add_argument("--field-types", required=1, nargs="+",
                      choices=["string", "float", "list-string", "list-float",
                               "ndarray", "concentration"],
                      help="Type of data in fields. Concentration is for molar concentrations.")
  parser.add_argument("--name", required=1,
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
@@ -79,14 +81,17 @@ def generate_fingerprints(name, out):
                   "circular", "--size", "1024"])

def globavir_specs():
  """Return the field names and field types for the Globavir input.

  Returns
  -------
  tuple
    (fields, field_types): two parallel lists of equal length, where
    field_types[i] is the type name of fields[i].
  """
  fields = [
      "compound_name", "isomeric_smiles",
      "tdo_ic50_nm", "tdo_Ki_nm",
      "tdo_percent_activity_10_um", "tdo_percent_activity_1_um",
      "ido_ic50_nm", "ido_Ki_nm",
      "ido_percent_activity_10_um", "ido_percent_activity_1_um",
  ]
  # The first two fields are strings; every remaining field is a float.
  field_types = ["string", "string"] + ["float"] * (len(fields) - 2)
  # Hand the specs back to the caller instead of discarding them.
  return fields, field_types

def get_rows(input_file, input_type):
  """Returns an iterator over all rows in input_file"""
  # TODO(rbharath): This function loads into memory, which can be painful. The
  # right option here might be to create a class which internally handles data
  # loading.
  if input_type == "xlsx":
    W = px.load_workbook(xlsx_file, use_iterators=True)
    p = W.get_sheet_by_name(name="Sheet1")
@@ -94,16 +99,24 @@ def get_rows(input_file, input_type):
  elif input_type == "csv":
    with open(csv_file, "rb") as f:
      reader = csv.reader(f, delimiter="\t")
    # TODO(rbharath): This loads into memory, which is painful. The right
    # option here might be to create a class which internally handles data
    # loading.
    return [row for row in reader]
  elif input_type == "pandas":
    with gzip.open(input_file) as f:
      df = pickle.load(f)
    return df.iterrows()
  elif input_type == "sdf":
    if ".gz" in input_file:
      with gzip.open(input_file) as f:
        supp = Chem.ForwardSDMolSupplier(f)
        mols = [mol for mol in supp if mol is not None]
      return mols
    else:
      with open(input_file) as f:
        supp  = Chem.ForwardSDMolSupplier(f)
        mols = [mol for mol in supp if mol is not None]
      return mols

def get_row_data(row, input_type, columns):
def get_row_data(row, input_type, fields, field_types):
  """Extract information from row data."""
  if input_type == "xlsx":
    return [cell.internal_value for cell in row]
@@ -112,37 +125,57 @@ def get_row_data(row, input_type, columns):
  elif input_type == "pandas":
    # pandas rows are tuples (row_num, row_info)
    row, row_data = row[1], {}
    # pandas rows are keyed by column-name. Change to key by index to match
    # pandas rows are keyed by field-name. Change to key by index to match
    # csv/xlsx handling
    for ind, column in enumerate(columns):
      row_data[ind] = row[column]
    for ind, field in enumerate(fields):
      row_data[ind] = row[field]
    return row_data
  elif input_type == "sdf":
    row_data, mol = {}, row
    for ind, (field, field_type) in enumerate(zip(fields, field_types)):
      # TODO(rbharath): SDF files typically don't have smiles, so we manually
      # generate smiles in this case. This is a kludgey solution...
      if field == "smiles":
        row_data[ind] = Chem.MolToSmiles(mol)
        continue
      if not mol.HasProp(field):
        row_data[ind] = None
      else:
        row_data[ind] = mol.GetProp(field)
    return row_data

def process_field(data, field_type):
  """Parse the raw contents of one field according to its declared type.

  Parameters
  ----------
  data: object
    Raw field value. The "list-string" and "list-float" types call
    data.split(","), so they require a string.
  field_type: str
    One of "string", "float", "concentration", "list-string", "list-float"
    or "ndarray".

  Returns
  -------
  The parsed value. "string" and "ndarray" return data unchanged. "float"
  returns whatever parse_float_input returns. "concentration" returns the
  parsed float divided by 1e-7, or None when parse_float_input returns None.
  "list-string" returns a list of strings and "list-float" a numpy float
  array. Any other field_type returns None.
  """
  if field_type in ("string", "ndarray"):
    return data
  elif field_type == "float":
    return parse_float_input(data)
  elif field_type == "concentration":
    # Parse once and reuse the result rather than parsing the field twice.
    value = parse_float_input(data)
    if value is None:
      return None
    # NOTE(review): the 1e-7 scale factor is undocumented; presumably a unit
    # conversion for molar concentrations -- confirm intended units.
    return value / 1e-7
  elif field_type == "list-string":
    return data.split(",")
  elif field_type == "list-float":
    # Convert the comma-separated tokens to floats; without an explicit dtype
    # numpy would keep them as strings.
    return np.array(data.split(","), dtype=float)
  return None

def generate_targets(input_file, input_type, columns, column_types, out_pkl,
def generate_targets(input_file, input_type, fields, field_types, out_pkl,
    out_sdf):
  """Process input data file."""
  rows, mols, smiles = [], [], SmilesGenerator()
  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
    print row_index
    print raw_row
    # Skip row labels.
    if row_index == 0:
    if row_index == 0 or raw_row is None:
      continue
    row, row_data = {}, get_row_data(raw_row, input_type, columns)
    for ind, (column, column_type) in enumerate(zip(columns, column_types)):
      row[column] = process_field(row_data[ind], column_type)
    row, row_data = {}, get_row_data(raw_row, input_type, fields, field_types)
    for ind, (field, field_type) in enumerate(zip(fields, field_types)):
      row[field] = process_field(row_data[ind], field_type)
    # TODO(rbharath): This patch is only in place until the smiles/sequence
    # support is fixed.
    if row["smiles"] is None:
@@ -167,8 +200,8 @@ def generate_targets(input_file, input_type, columns, column_types, out_pkl,
def main():
  """Run the dataset-processing steps using command-line arguments.

  Parses the arguments, obtains the two output locations from
  generate_directories, generates targets from the input file using the
  --fields / --field-types arguments, then generates fingerprints.
  """
  args = parse_args()
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  # The parser exposes --fields/--field-types, so only the fields-based call
  # is valid; targets are generated exactly once.
  generate_targets(args.input_file, args.input_type, args.fields,
      args.field_types, out_pkl, out_sdf)
  generate_fingerprints(args.name, args.out)


Loading