Commit d4519b34 authored by evanfeinberg's avatar evanfeinberg
Browse files

Fixed an issue with the Convolution3D border mode (now explicitly "full") and factored the data-shuffling helpers out to module level.

parent fde59eb8
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -14,7 +14,6 @@ class Model(object):
  def __init__(self, task_types, model_params, initialize_raw_model=True):
    """Store model configuration.

    Parameters
    ----------
    task_types:
      Per-task type information (stored verbatim; not interpreted here).
    model_params:
      Dictionary-like bundle of hyperparameters (stored verbatim).
    initialize_raw_model: bool
      Unused in this base class — presumably subclasses consult it to decide
      whether to construct their underlying framework model. TODO confirm.
    """
    self.task_types = task_types
    self.model_params = model_params
    # Subclasses are expected to populate this with the underlying
    # framework model (e.g. a compiled Keras model).
    self.raw_model = None

  def fit_on_batch(self, X, y, w):
    """
+17 −20
Original line number Diff line number Diff line
@@ -13,6 +13,17 @@ from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D
from deep_chem.models import Model

def shuffle_shape(shape):
  """Convert a channels-last cubic volume shape to channels-first.

  Given a 4-tuple ``(N, N, N, n_channels)`` describing one cubic sample,
  return ``(n_channels, N, N, N)`` — the axis order the 3D convnet expects.
  Raises ValueError if ``shape`` does not have exactly four entries.
  """
  side, _, _, channels = shape
  return (channels, side, side, side)

def shuffle_data(X):
  """Reshape a batch of channels-last cubic volumes to channels-first.

  Takes a 5D array of shape ``(n_samples, N, N, N, n_channels)`` and returns
  it reshaped to ``(n_samples, n_channels, N, N, N)``. Note this is a
  ``reshape`` (reinterpretation of the flat buffer), not a transpose.
  """
  n_samples, side, _, _, n_channels = np.shape(X)
  target = (n_samples, n_channels, side, side, side)
  return np.reshape(X, target)


class DockingDNN(Model):
  """
  Wrapper class for fitting 3D convolutional networks for deep docking.
@@ -39,25 +50,24 @@ class DockingDNN(Model):

      model.add(Convolution3D(nb_filter=nb_filters[0], nb_depth=nb_conv[0], 
                              nb_row=nb_conv[0], nb_col=nb_conv[0],
                              input_shape=self.input_shape))
                              input_shape=self.input_shape, border_mode="full"))
      model.add(Activation('relu'))

      model.add(MaxPooling3D(pool_size=(nb_pool[0], nb_pool[0], nb_pool[0])))
      model.add(Convolution3D(nb_filter=nb_filters[1],  nb_depth=nb_conv[1],
                              nb_row=nb_conv[1], nb_col=nb_conv[1]))
                              nb_row=nb_conv[1], nb_col=nb_conv[1], border_mode="full"))
      model.add(Activation('relu'))
      model.add(MaxPooling3D(pool_size=(nb_pool[1], nb_pool[1], nb_pool[1])))
      model.add(Convolution3D(nb_filter=nb_filters[2], nb_depth=nb_conv[2],
                              nb_row=nb_conv[2], nb_col=nb_conv[2]))
                              nb_row=nb_conv[2], nb_col=nb_conv[2], border_mode="full"))
      model.add(Activation('relu'))
      model.add(MaxPooling3D(pool_size=(nb_pool[2], nb_pool[2], nb_pool[2])))
      model.add(Flatten())
      # TODO(rbharath): If we change away from axis-size 32, this code will break.
      # Eventually figure out a more general rule that works for all axis sizes.
      model.add(Dense(32/2, init='normal'))
      model.add(Dense(16, init='normal'))
      model.add(Activation('relu'))
      model.add(Dropout(0.5))
      # TODO(rbharath): Generalize this to support classification as well as regression.
      model.add(Dense(1, init='normal'))

      sgd = RMSprop(lr=learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
@@ -65,29 +75,16 @@ class DockingDNN(Model):
      model.compile(loss=loss_function, optimizer=sgd)
      self.raw_model = model

  def shuffle_data(self, X):
    """Reshape channels-last volumes ``(n_samples, N, N, N, n_channels)``
    into the channels-first layout ``(n_samples, n_channels, N, N, N)``.

    Plain ``reshape`` of the flat buffer, not an axis transpose.
    """
    n_samples, side, _, _, n_channels = np.shape(X)
    return np.reshape(X, (n_samples, n_channels, side, side, side))


  def fit_on_batch(self, X, y, w):
    """Train the underlying 3D convnet on one batch.

    X is expected channels-last, shape (n_samples, N, N, N, n_channels);
    it is reshaped to channels-first before being handed to Keras.
    ``w`` (sample weights) is accepted for interface compatibility but
    not forwarded to train_on_batch.
    """
    print("Training 3D model")
    print("Original shape of X: " + str(np.shape(X)))
    print("Shuffling X dimensions to match convnet")
    # TODO(rbharath): Modify the featurization so that it matches desired shaped.
    # NOTE(review): X appears to be shuffled TWICE below (here via the
    # self.shuffle_data method and again via the module-level shuffle_data).
    # The second reshape reinterprets an already channels-first buffer and
    # would scramble the data — looks like a merge/diff artifact; confirm
    # only one of the two calls should remain.
    X = self.shuffle_data(X)
    print("Final shape of X: " + str(np.shape(X)))

    print("About to fit data to model.")
    X = shuffle_data(X)
    self.raw_model.train_on_batch(X, y)
    print("Finished training on batch.")

  def predict_on_batch(self, X):
    """Predict regression outputs for one batch.

    Parameters
    ----------
    X: ndarray
      Channels-last tensor of shape (n_samples, N, N, N, n_channels).

    Returns
    -------
    ndarray
      Squeezed predictions from the underlying Keras model.

    Raises
    ------
    ValueError
      If X is not a 5D tensor.
    """
    if len(np.shape(X)) != 5:
      raise ValueError(
          "Tensorial datatype must be of shape (n_samples, N, N, N, n_channels).")
    # Reshape channels-last input to the channels-first layout the convnet
    # expects. The original text carried BOTH the superseded
    # `X = self.shuffle_data(X)` and this call — shuffling twice would
    # reinterpret an already channels-first buffer and scramble the data,
    # so only the module-level call is kept.
    X = shuffle_data(X)
    y_pred = self.raw_model.predict_on_batch(X)
    y_pred = np.squeeze(y_pred)
    return y_pred
+8 −4
Original line number Diff line number Diff line
@@ -11,16 +11,20 @@ from deep_chem.models.deep import MultiTaskDNN
from deep_chem.models.deep3d import DockingDNN
from deep_chem.models.standard import SklearnModel

def model_builder(model_type, task_types, model_params,
                  initialize_raw_model=True):
  """
  Factory function to construct model.

  Parameters
  ----------
  model_type: str
    One of "singletask_deep_network", "multitask_deep_network",
    "convolutional_3D_regressor"; anything else falls through to a
    scikit-learn model.
  task_types:
    Per-task type information, forwarded to the model constructor.
  model_params:
    Hyperparameter bundle, forwarded to the model constructor.
  initialize_raw_model: bool
    Whether the deep-network wrappers should build their underlying
    framework model immediately. Not forwarded to SklearnModel, which
    does not take it.
  """
  if model_type == "singletask_deep_network":
    model = SingleTaskDNN(task_types, model_params,
                          initialize_raw_model)
  elif model_type == "multitask_deep_network":
    model = MultiTaskDNN(task_types, model_params,
                         initialize_raw_model)
  elif model_type == "convolutional_3D_regressor":
    model = DockingDNN(task_types, model_params,
                       initialize_raw_model)
  else:
    model = SklearnModel(task_types, model_params)
  return model
+0 −2
Original line number Diff line number Diff line
@@ -310,9 +310,7 @@ def featurize_inputs_wrapper(args):
def train_test_split_wrapper(args):
  """Wrapper function that calls train_test_split after unwrapping args.

  Forwards the parsed command-line namespace's fields positionally; the
  stray blank lines that had crept into the middle of the call are removed.
  """
  train_test_split(args.paths, args.output_transforms,
                   args.input_transforms, args.feature_types,
                   args.splittype, args.mode, args.data_dir)

def fit_model_wrapper(args):
+19 −16
Original line number Diff line number Diff line
@@ -45,8 +45,11 @@ def compute_y_pred(model, data_dir, csv_out, split):
  column_names = ['ids'] + task_names + pred_task_names + w_task_names
  pred_y_df = pd.DataFrame(columns=column_names)

  for _, row in metadata_df.iterrows():
    if row['split'] == split:
  split_df = metadata_df.loc[metadata_df['split'] == split]
  nb_batch = split_df.shape[0]

  for i, row in split_df.iterrows():
    print("Evaluating on %s batch %d out of %d" % (split, i+1, nb_batch))
    X = load_sharded_dataset(row['X'])
    y = load_sharded_dataset(row['y'])
    w = load_sharded_dataset(row['w'])
Loading