Commit a4ab24f3 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

STDNN now runs to completion, but numbers look a little suspect.

parent b191fe4c
Loading
Loading
Loading
Loading
+36 −23
Original line number Diff line number Diff line
@@ -51,17 +51,19 @@ class Model(object):
    """
    return(self.raw_model)

  def get_param_filename(self, out_dir):
  @staticmethod
  def get_model_filename(out_dir):
    """
    Given model directory, obtain filename for the model itself.
    """
    return os.path.join(out_dir, "model_params.joblib")
    return os.path.join(out_dir, "model.joblib")

  def get_model_filename(self, out_dir):
  @staticmethod
  def get_params_filename(out_dir):
    """
    Given model directory, obtain filename for the model itself.
    """
    return os.path.join(out_dir, "model.joblib")
    return os.path.join(out_dir, "model_params.joblib")

  @staticmethod
  def model_builder(model_type, task_types, model_params,
@@ -83,20 +85,28 @@ class Model(object):
    """
    Model.registered_model_types[model_type] = model_class

  def load(self, model_dir):
  @staticmethod
  def load(model_type, model_dir):
    """Dispatcher function for loading."""
    params = load_from_disk(self.get_model_filename(model_dir))
    self.model_params = params["model_params"]
    self.task_types = params["task_types"]
    self.model_type = params["model_type"]
    params = load_from_disk(Model.get_params_filename(model_dir))
    if model_type in Model.registered_model_types:
      model = Model.registered_model_types[model_type](
          model_type=params["model_type"],
          task_types=params["task_types"],
          model_params=params["model_params"])
      model.load(model_dir)
    else:
      raise ValueError("model_type %s is not supported" % model_type)
    return model

  def save(self, out_dir):
    """Dispatcher function for saving."""
    params = {"model_params" : self.model_params,
              "task_types" : self.task_types,
              "model_type": self.model_type}
    save_to_disk(params, self.get_params_filename(out_dir))
    save_to_disk(params, Model.get_params_filename(out_dir))

  # TODO(rbharath): This training is currently broken w.r.t minibatches! Fix.
  def fit(self, sharded_dataset):
    """
    Fits a model on data in a ShardedDataset object.
@@ -104,12 +114,15 @@ class Model(object):
    # TODO(rbharath/enf): This GPU_RAM is black magic. Needs to be removed/made
    # more general.
    MAX_GPU_RAM = float(691007488/50)
    for (X, y, w, _) in sharded_dataset.itershards():
    for epoch in range(self.model_params["nb_epoch"]):
      print("Starting epoch %s" % str(epoch+1))
      for i, (X, y, w, _) in enumerate(sharded_dataset.itershards()):
        print("Training on batch-%s/epoch-%s" % (str(i+1), str(epoch+1)))
        if sys.getsizeof(X) > MAX_GPU_RAM:
          nb_block = float(sys.getsizeof(X))/MAX_GPU_RAM
          nb_sample = np.shape(X)[0]
        interval_points = np.linspace(0,nb_sample,nb_block+1).astype(int)
        for j in range(0,len(interval_points)-1):
          interval_points = np.linspace(nb_sample,nb_block+1).astype(int)
          for j in range(len(interval_points)-1):
            indices = range(interval_points[j],interval_points[j+1])
            X_batch = X[indices,:]
            y_batch = y[indices]
+10 −9
Original line number Diff line number Diff line
"""
Code for processing the Google vs-datasets using keras.
"""
import os
import numpy as np
from keras.models import Graph
from keras.models import model_from_json
@@ -17,9 +18,9 @@ class KerasModel(Model):
    """
    Saves underlying keras model to disk. 
    """
    super(MultiTaskDNN, self).save(out_dir)
    super(KerasModel, self).save(out_dir)
    model = self.get_raw_model()
    filename, _ = os.path.splitext(self.get_model_filename(out_dir))
    filename, _ = os.path.splitext(Model.get_model_filename(out_dir))

    # Note that keras requires the model architecture and weights to be stored
    # separately. A json file is generated that specifies the model architecture.
@@ -37,8 +38,7 @@ class KerasModel(Model):
    """
    Load keras multitask DNN from disk.
    """
    super(MultiTaskDNN, self).load(model_dir)
    filename = self.get_Model_filename(model_dir)
    filename = Model.get_model_filename(model_dir)
    filename, _ = os.path.splitext(filename)

    json_filename = "%s.%s" % (filename, "json")
@@ -53,7 +53,8 @@ class MultiTaskDNN(KerasModel):
  """
  Model for multitask MLP in keras.
  """
  def __init__(self, task_types, model_params, initialize_raw_model=True):
  def __init__(self, model_type, task_types, model_params,
               initialize_raw_model=True):
    super(MultiTaskDNN, self).__init__(model_type, task_types, model_params,
                                       initialize_raw_model)
    if initialize_raw_model:
@@ -147,8 +148,8 @@ class SingleTaskDNN(MultiTaskDNN):
  """
  Abstract base class for different ML models.
  """
  def __init__(self, task_types, model_params, initialize_raw_model=True):
    super(SingleTaskDNN, self).__init__(task_types, model_params,
  def __init__(self, model_type, task_types, model_params, initialize_raw_model=True):
    super(SingleTaskDNN, self).__init__(model_type, task_types, model_params,
                                        initialize_raw_model)

Model.register_model_type("singletask_deep_regressor", SingleTaskDNN)
+1 −2
Original line number Diff line number Diff line
@@ -70,8 +70,7 @@ class SklearnModel(Model):

  def load(self, model_dir):
    """Loads sklearn model from joblib file on disk.

    Expects that Model.load (the dispatcher) has already reconstructed this
    instance from the saved params; this only restores the raw sklearn model.
    """
    # Fixed: the committed line was missing its closing parenthesis,
    # which is a syntax error.
    self.raw_model = joblib.load(Model.get_model_filename(model_dir))

Model.register_model_type("logistic", SklearnModel)
Model.register_model_type("rf_classifier", SklearnModel)
+0 −3
Original line number Diff line number Diff line
@@ -262,9 +262,6 @@ def create_model(args):
  stats_out_train = os.path.join(data_dir, "train-stats.txt")
  csv_out_test = os.path.join(data_dir, "test.csv")
  stats_out_test = os.path.join(data_dir, "test-stats.txt")
  print("create_model()")
  print("args.output_transforms")
  print(args.output_transforms)
  train_dir = os.path.join(data_dir, "train")
  eval_trained_model(
      model_name, model_dir, train_dir, csv_out_train,
+22 −27
Original line number Diff line number Diff line
@@ -35,7 +35,6 @@ def df_to_numpy(df, mode, feature_types):

  y = df[sorted_tasks].values
  w = np.ones((n_samples, n_tasks))
  w[np.where(y=='')] = 0

  tensors = []
  for i, datapoint in df.iterrows():
@@ -44,27 +43,19 @@ def df_to_numpy(df, mode, feature_types):
      feature_list.append(datapoint[feature_type])
    features = np.squeeze(np.concatenate(feature_list))
    tensors.append(features)

  x = np.stack(tensors)
  sorted_ids = df['mol_id']
  return sorted_ids, x, y, w

def get_train_test_files(paths, splittype, train_proportion=0.8):
  """
  Randomly split files into train and test.
  """
  #all_files = []
  #for path in paths:
  #  all_files += glob(os.path.join(path, "*.joblib"))
  train_indices = list(
      np.random.choice(len(all_files), int(len(all_files)*train_proportion),
                       replace=False))
  test_indices = list(set(range(len(all_files)))-set(train_indices))

  train_files = [all_files[i] for i in train_indices]
  test_files = [f for f in all_files if f not in train_files]
  return train_files, test_files
  # Remove entries with missing labels
  nonzero_labels = np.where(np.squeeze(y)!='') 
  x = x[nonzero_labels]
  y = y[nonzero_labels]
  w = w[nonzero_labels]
  nonzero_rows = []
  for nonzero_ind in np.squeeze(nonzero_labels):
    nonzero_rows.append(df.iloc[nonzero_ind])
  sorted_ids = pd.DataFrame(nonzero_rows)["mol_id"]

  return sorted_ids, x, y, w

# TODO(rbharath): Should this be a method?
def get_sorted_task_names(df):
@@ -82,8 +73,11 @@ class FeaturizedSamples(object):
  """

  # The standard columns for featurized data.
  colnames = ["mol_id", "smiles", "split", "features", "descriptors",
              "fingerprints"]
  # TODO(rbharath): colnames are implicitly set in class Samples. Needs to be
  # moved into Samples to avoid bugs (ran into issues when changing
  # "fingerprints" -> "ECFP")
  colnames = ["mol_id", "smiles", "split", "user-specified-features", "RDKIT-descriptors",
              "ECFP"]

  def __init__(self, paths=None, dataset_files=[], compound_df=None):
    if paths is not None:
@@ -375,9 +369,9 @@ def _transform_row(i, df, normalize_X, normalize_y, truncate_X, truncate_y,
  save_to_disk(X, row['X-transformed'])

  y = load_from_disk(row['y'])
  w = load_from_disk(row['w'])
  if normalize_y or log_y:    
    if normalize_y:
      print("Normalizing y sample %d out of %d" % (i+1,total))
      y = np.nan_to_num((y - y_means) / y_stds)
      if truncate_y:
        y[y > trunc] = trunc
@@ -393,6 +387,10 @@ def compute_sums_and_nb_sample(tensor, W=None):

  If W is specified, only nonzero weight entries of tensor are used.
  """
  if len(np.shape(tensor)) == 1:
    tensor = np.reshape(tensor, (len(tensor), 1))
  if W is not None and len(np.shape(W)) == 1:
    W = np.reshape(W, (len(W), 1))
  if W is None:
    sums = np.sum(tensor, axis=0)
    sum_squares = np.sum(np.square(tensor), axis=0)
@@ -402,17 +400,14 @@ def compute_sums_and_nb_sample(tensor, W=None):
    sums = np.zeros((nb_task))
    sum_squares = np.zeros((nb_task))
    nb_sample = np.zeros((nb_task))
    for task in range(0, nb_task):
    for task in range(nb_task):
      y_task = tensor[:,task]
      W_task = W[:,task]
      nonzero_indices = np.nonzero(W_task)
      nonzero_indices = np.nonzero(W_task)[0]
      y_task_nonzero = y_task[nonzero_indices]
      sums[task] = np.sum(y_task_nonzero)
      sum_squares[task] = np.dot(y_task_nonzero, y_task_nonzero)
      nb_sample[task] = np.shape(y_task_nonzero)[0]
  print("compute_sums_and_nb_sample()")
  print("np.shape(tensor)")
  print(np.shape(tensor))
  return (sums, sum_squares, nb_sample)

def compute_mean_and_std(df):
Loading