Commit 697edab1 authored by Bharath Ramsundar

First steps for 3D_CNN support.

parent 4b69e74a
+2 −2
@@ -33,7 +33,7 @@ def process_multitask(paths, task_transforms, splittype="random",
  Parameters
  ----------
  paths: list 
-    List of paths to Google vs datasets.
+    List of paths to datasets.
  task_transforms: dict 
    dict mapping target names to label transform. Each output type must be either
    None, "log", "normalize" or "log-normalize". Only for regression outputs.
@@ -216,7 +216,7 @@ def train_multitask_model(X, y, W, task_types,
    Momentum used in SGD.
  nesterov: bool
    Use Nesterov acceleration
-  n_epochs: int
+  nb_epoch: int
    maximal number of epochs to run the optimizer
  """
  eps = .001
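
The only change in this hunk is documentation: the epoch count is now documented under the keyword the function actually accepts, nb_epoch (the Keras-style name), rather than n_epochs. A minimal sketch of a call using that keyword; the import path, array shapes, and the assumption that the remaining parameters all have defaults are illustrative guesses, not part of this commit:

import numpy as np
from deep_chem.models.deep import train_multitask_model  # module path assumed

# Hypothetical toy inputs: 100 compounds, 1024 features, 5 binary tasks.
X = np.random.rand(100, 1024)
y = np.random.randint(2, size=(100, 5))
W = np.ones_like(y, dtype=float)  # uniform example weights
task_types = {task: "classification" for task in range(5)}

train_multitask_model(X, y, W, task_types,
                      momentum=.9, nesterov=True, nb_epoch=10)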
deep_chem/models/deep3d.py
+50 −6
"""
Code for training 3D convolutions.
"""
from deep_chem.datasets.shapes_3d import load_data
# NOTE: the helpers below are called in process_3D_convolutions but not
# imported anywhere in this file; the module path is assumed.
from deep_chem.utils.load import load_and_transform_dataset
from deep_chem.utils.load import dataset_to_numpy
from deep_chem.utils.load import train_test_random_split
from deep_chem.utils.load import train_test_scaffold_split
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D
from keras.utils import np_utils
import numpy as np

-def fit_3D_convolution():
+# TODO(rbharath): Factor this out into a separate function in utils. Duplicates
+# code in deep.py
+def process_3D_convolutions(paths, task_transforms, splittype="random",
+                            seed=None):
  """Loads 3D Convolution datasets.

  Parameters
  ----------
  paths: list
    List of paths to convolution datasets.
  """
  dataset = load_and_transform_dataset(paths, task_transforms)
  # TODO(rbharath): Factor this code splitting out into a util function.
  if splittype == "random":
    train, test = train_test_random_split(dataset, seed=seed)
  elif splittype == "scaffold":
    train, test = train_test_scaffold_split(dataset)
  X_train, y_train, W_train = dataset_to_numpy(train)
  X_test, y_test, W_test = dataset_to_numpy(test)
  return (X_train, y_train, W_train), (X_test, y_test, W_test)

+def fit_3D_convolution(axis_length=32, **training_params):
  """
  Perform stochastic gradient descent for a 3D CNN.
  """
-  pass
+  nb_classes = 2
+  (X_train, y_train), (X_test, y_test) = load_data(axis_length=axis_length)
+  y_train = np_utils.to_categorical(y_train, nb_classes)
+  y_test = np_utils.to_categorical(y_test, nb_classes)
+  print "np.shape(X_train): " + str(np.shape(X_train))
+  print "np.shape(y_train): " + str(np.shape(y_train))
+  train_3D_convolution(X_train, y_train, axis_length, **training_params)
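
Note that fit_3D_convolution draws its data from the synthetic shapes_3d loader rather than from user-supplied paths, so exercising it only requires training parameters. A sketch (the values are arbitrary):

from deep_chem.models.deep3d import fit_3D_convolution

# Trains the 3D CNN on the synthetic shapes dataset; extra keyword
# arguments are forwarded to train_3D_convolution.
fit_3D_convolution(axis_length=32, batch_size=50, nb_epoch=5)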

-def train_3D_convolution(X, y):
+def train_3D_convolution(X, y, axis_length=32, batch_size=50, nb_epoch=1):
  """
  Fit a keras 3D CNN to data.

  Parameters
  ----------
+  nb_epoch: int
+    maximal number of epochs to run the optimizer
  """
  print "train_3D_convolution"
  print "axis_length: " + str(axis_length)
  # Number of classes for classification
  nb_classes = 2

  # number of convolutional filters to use at each layer
-  nb_filters = [16, 32, 32]
+  nb_filters = [axis_length/2, axis_length, axis_length]
+  print "nb_filters: " + str(nb_filters)

  # level of pooling to perform at each layer (POOL x POOL)
  nb_pool = [2, 2, 2]
@@ -38,10 +82,10 @@ def train_3D_convolution(X, y):
  model.add(Activation('relu'))
  model.add(MaxPooling3D(poolsize=(nb_pool[2], nb_pool[2], nb_pool[2])))
  model.add(Flatten())
-  model.add(Dense(32, 16, init='normal'))
+  model.add(Dense(320, 32/2, init='normal'))
  model.add(Activation('relu'))
  model.add(Dropout(0.5))
-  model.add(Dense(16, nb_classes, init='normal'))
+  model.add(Dense(32/2, nb_classes, init='normal'))
  model.add(Activation('softmax'))

  sgd = RMSprop(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
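
The visible hunk stops just after the optimizer is constructed (the variable keeps its old name sgd even though it now holds an RMSprop instance). The compile-and-fit steps fall outside the shown context; under the Keras API of this era they would presumably look like this sketch:

model.compile(loss='categorical_crossentropy', optimizer=sgd)
model.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch)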

deep_chem/public_data/__init__.py
deleted file mode 100644
+0 −0
Empty file deleted.

+14 −3
@@ -5,21 +5,23 @@ import argparse
import numpy as np
from deep_chem.models.deep import fit_singletask_mlp
from deep_chem.models.deep import fit_multitask_mlp
+from deep_chem.models.deep3d import fit_3D_convolution
from deep_chem.models.standard import fit_singletask_models
from deep_chem.utils.load import get_default_task_types_and_transforms

def parse_args(input_args=None):
  """Parse command-line arguments."""
  parser = argparse.ArgumentParser()
-  parser.add_argument('--datasets', required=1, nargs="+",
+  parser.add_argument('--datasets', nargs="+", required=1,
                      choices=['muv', 'pcba', 'dude', 'pfizer', 'globavir', 'pdbbind'],
                      help='Name of dataset to process.')
  parser.add_argument("--paths", required=1, nargs="+",
  parser.add_argument("--paths", nargs="+", required=1,
                      help = "Paths to input datasets.")
  parser.add_argument('--model', required=1,
                      choices=["logistic", "rf_classifier", "rf_regressor",
                      "linear", "ridge", "lasso", "lasso_lars", "elastic_net",
                      "singletask_deep_network", "multitask_deep_network"])
                      "singletask_deep_network", "multitask_deep_network",
                      "3D_cnn"])
  parser.add_argument("--splittype", type=str, default="scaffold",
                       choices=["scaffold", "random"],
                       help="Type of cross-validation data-splitting.")
@@ -42,11 +44,16 @@ def parse_args(input_args=None):
  # TODO(rbharath): Remove this once debugging is complete.
  parser.add_argument("--num-to-train", type=int, default=None,
                  help="Number of datasets to train on. Only for debug.")
  parser.add_argument("--axis-length", type=int, default=32,
                  help="Size of a grid axis for 3D CNN input.")
      
  return parser.parse_args(input_args)

def main():
  args = parse_args()
  paths = {}


  for dataset, path in zip(args.datasets, args.paths):
    paths[dataset] = path

@@ -66,6 +73,10 @@ def main():
      nb_epoch=args.n_epochs, decay=args.decay,
      validation_split=args.validation_split,
      weight_positives=args.weight_positives)
  elif args.model == "3D_cnn":
    fit_3D_convolution(paths.values(), task_types, task_transforms,
        axis_length=args.axis_length, nb_epoch=args.n_epochs,
        batch_size=args.batch_size)
  else:
    fit_singletask_models(paths.values(), args.model, task_types,
        task_transforms, splittype=args.splittype, num_to_train=args.num_to_train)
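
With the 3D_cnn model choice and the --axis-length flag in place, the new path can be smoke-tested from Python, since parse_args accepts an explicit argument list. The dataset name and path below are placeholders:

args = parse_args([
    "--datasets", "pdbbind",
    "--paths", "/path/to/pdbbind",
    "--model", "3D_cnn",
    "--splittype", "random",
    "--axis-length", "32",
])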
+63 −37
@@ -18,11 +18,15 @@ def parse_args(input_args=None):
  parser = argparse.ArgumentParser()
  parser.add_argument('--input-file', required=1,
                      help='Input file with data.')
  parser.add_argument("--input-type", default="csv",
                      choices=["csv", "pandas"],
                      help="Type of input file. The pkl.gz must contain a pandas dataframe.")
  parser.add_argument("--columns", required=1, nargs="+",
                      help = "Names of columns.")
  parser.add_argument('--column-types', required=1, nargs="+",
-                      choices=['string', 'float', 'list', 'float-array'],
-                      help='Name of dataset to process.')
+                      choices=['string', 'float', 'list-string', 'list-float',
+                               'ndarray'],
+                      help='Type of data in columns.')
  parser.add_argument("--name", required=1,
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
@@ -81,48 +85,69 @@ def globavir_specs():
  column_types = ["string", "string", "float", "float", "float", "float",
      "float", "float", "float", "float"]

-def gen_xlsx_rows(xlxs_file):
+def get_rows(input_file, input_type):
+  """Returns an iterator over all rows in input_file"""
+  if input_type == "xlsx":
+    W = px.load_workbook(input_file, use_iterators=True)
+    p = W.get_sheet_by_name(name="Sheet1")
+    return p.iter_rows()
+  elif input_type == "csv":
+    # TODO(rbharath): This loads into memory, which is painful. The right
+    # option here might be to create a class which internally handles data
+    # loading.
+    with open(input_file, "rb") as f:
+      reader = csv.reader(f, delimiter="\t")
+      return [row for row in reader]
+  elif input_type == "pandas":
+    with gzip.open(input_file) as f:
+      df = pickle.load(f)
+    return df.iterrows()

-def get_xlsx_row_data(row):
+def get_row_data(row, input_type, columns):
  """Extract information from row data."""
+  if input_type == "xlsx":
+    return [cell.internal_value for cell in row]
+  elif input_type == "csv":
+    return row
+  elif input_type == "pandas":
+    # pandas rows are tuples (row_num, row_info)
+    row, row_data = row[1], {}
+    # pandas rows are keyed by column-name. Change to key by index to match
+    # csv/xlsx handling
+    for ind, column in enumerate(columns):
+      row_data[ind] = row[column]
+    return row_data

-def gen_csv_rows(csv_file):
-  # This is a memory leak...
-  f = open(csv_file, "rb")
-  return csv.reader(f, delimiter="\t")
+def process_field(data, column_type):
+  """Parse data in a field."""
+  if column_type == "string":
+    return data
+  elif column_type == "float":
+    return parse_float_input(data)
+  elif column_type == "list-string":
+    return data.split(",")
+  elif column_type == "list-float":
+    # dtype=float so the entries are parsed rather than stored as strings.
+    return np.array(data.split(","), dtype=float)
+  elif column_type == "ndarray":
+    return data
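
process_field centralizes the per-column parsing that generate_targets previously inlined. Illustrative calls from within this module (input values are made up):

process_field("1.5", "float")            # 1.5, via parse_float_input
process_field("CCO,CCN", "list-string")  # ["CCO", "CCN"]
process_field("0.1,0.2", "list-float")   # np.array([0.1, 0.2])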

-def generate_targets(input_file, columns, column_types, out_pkl, out_sdf, type="csv"):
+def generate_targets(input_file, input_type, columns, column_types, out_pkl,
+    out_sdf):
  """Process input data file."""
-  rows, mols = [], []
-  smiles = SmilesGenerator()
-  if type == "xlsx":
-    row_gen = gen_xlsx_rows(input_file)
-  elif type == "csv":
-    row_gen = gen_csv_rows(input_file)
-  for row_index, raw_row in enumerate(row_gen):
+  rows, mols, smiles = [], [], SmilesGenerator()
+  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
    print row_index
    # Skip row labels.
    if row_index == 0:
      continue
    if type == "xlsx":
      row_data = get_xlsx_row_data(raw_row)
    elif type == "csv":
      row_data = raw_row 
      
    row = {}
    row, row_data = {}, get_row_data(raw_row, input_type, columns)
    for ind, (column, column_type) in enumerate(zip(columns, column_types)):
      if column_type == "string":
        row[column] = row_data[ind]
      elif column_type == "float":
        row[column] = parse_float_input(row_data[ind])
      elif column_type == "list":
        row[column] = row_data[ind].split(",")
      elif column_type == "float-array":
        row[column] = np.array(row_data[ind].split(","))

      row[column] = process_field(row_data[ind], column_type)
    # TODO(rbharath): This patch is only in place until the smiles/sequence
    # support is fixed.
    if row["smiles"] is None:
      mol = Chem.MolFromSmiles("C")
    else:
      mol = Chem.MolFromSmiles(row["smiles"])
    row["smiles"] = smiles.get_smiles(mol)
    mols.append(mol)
@@ -141,7 +166,8 @@ def generate_targets(input_file, columns, column_types, out_pkl, out_sdf, type="
def main():
  args = parse_args()
  out_pkl, out_sdf = generate_directories(args.name, args.out)
-  generate_targets(args.input_file, args.columns, args.column_types, out_pkl, out_sdf)
+  generate_targets(args.input_file, args.input_type, args.columns,
+      args.column_types, out_pkl, out_sdf)
  generate_fingerprints(args.name, args.out)
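
With the new --input-type plumbing, a gzipped pickled dataframe goes through the same pipeline as a csv. A hypothetical invocation from Python (file, column, and dataset names are placeholders):

args = parse_args([
    "--input-file", "assay_data.pkl.gz",
    "--input-type", "pandas",
    "--columns", "smiles", "measured_activity",
    "--column-types", "string", "float",
    "--name", "example_assay",
    "--out", "/tmp/processed",
])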