Commit 1059672d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes in response to review.

parent aee5a160
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@ from deep_chem.utils.evaluate import compute_r2_scores

# TODO(rbharath): Factor this out into a separate function in utils. Duplicates
# code in deep.py
# TODO(rbharath): paths is to handle sharded input pickle files. Might be
# better to use hdf5 datasets like in MSMBuilder
def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random"):
  """Loads 3D Convolution datasets.

@@ -50,7 +52,6 @@ def fit_3D_convolution(paths, task_types, task_transforms, axis_length=32, **tra
      modeltype="keras", mode="tensor")
  local_task_types = task_types.copy()
  r2s = compute_r2_scores(results, local_task_types)
  if r2s:
  print "Mean R^2: %f" % np.mean(np.array(r2s.values()))

def train_3D_convolution(X, y, axis_length=32, batch_size=50, nb_epoch=1):
@@ -82,8 +83,7 @@ def train_3D_convolution(X, y, axis_length=32, batch_size=50, nb_epoch=1):
  nb_conv = [7, 5, 3]

  model = Sequential()
  # TODO(rbharath): Avoid hard coding the number of staks here
  model.add(Convolution3D(nb_filter=nb_filters[0], stack_size=3,
  model.add(Convolution3D(nb_filter=nb_filters[0], stack_size=n_channels,
     nb_row=nb_conv[0], nb_col=nb_conv[0], nb_depth=nb_conv[0],
     border_mode='valid'))
  model.add(Activation('relu'))