Commit 2c0ea7f2 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Merge branch 'master' of https://github.com/pandegroup/deep-learning into multitask

parents 416c577c 1d890efc
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -9,3 +9,4 @@ Requirements
* [sklearn](https://github.com/scikit-learn/scikit-learn.git)
* [numpy](https://store.continuum.io/cshop/anaconda/)
* [keras](keras.io)
+4 −5
Original line number Diff line number Diff line
@@ -10,12 +10,13 @@ from keras.utils import np_utils
from deep_chem.utils.preprocess import train_test_random_split
from deep_chem.utils.load import load_and_transform_dataset
from deep_chem.utils.preprocess import tensor_dataset_to_numpy
from deep_chem.datasets.shapes_3d import load_data
from deep_chem.utils.evaluate import eval_model
from deep_chem.utils.evaluate import compute_r2_scores

# TODO(rbharath): Factor this out into a separate function in utils. Duplicates
# code in deep.py
# TODO(rbharath): paths is to handle sharded input pickle files. Might be
# better to use hdf5 datasets like in MSMBuilder
def process_3D_convolutions(paths, task_transforms, seed=None, splittype="random"):
  """Loads 3D Convolution datasets.

@@ -50,7 +51,6 @@ def fit_3D_convolution(paths, task_types, task_transforms, axis_length=32, **tra
      modeltype="keras", mode="tensor")
  local_task_types = task_types.copy()
  r2s = compute_r2_scores(results, local_task_types)
  if r2s:
  print "Mean R^2: %f" % np.mean(np.array(r2s.values()))

def train_3D_convolution(X, y, axis_length=32, batch_size=50, nb_epoch=1):
@@ -82,8 +82,7 @@ def train_3D_convolution(X, y, axis_length=32, batch_size=50, nb_epoch=1):
  nb_conv = [7, 5, 3]

  model = Sequential()
  # TODO(rbharath): Avoid hard coding the number of staks here
  model.add(Convolution3D(nb_filter=nb_filters[0], stack_size=3,
  model.add(Convolution3D(nb_filter=nb_filters[0], stack_size=n_channels,
     nb_row=nb_conv[0], nb_col=nb_conv[0], nb_depth=nb_conv[0],
     border_mode='valid'))
  model.add(Activation('relu'))