Commit 87f92ee1 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

local

parent 728f6fd0
Loading
Loading
Loading
Loading
+15 −6
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
@@ -50,17 +50,26 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rbharath/anaconda3/envs/deepchem/lib/python3.6/site-packages/matplotlib-1.5.3-py3.6-linux-x86_64.egg/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      "  warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')\n",
      "/home/rbharath/anaconda3/envs/deepchem/lib/python3.6/site-packages/matplotlib-1.5.3-py3.6-linux-x86_64.egg/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      "  warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')\n"
      "/home/rbharath/anaconda3/envs/deepchem/lib/python3.6/site-packages/matplotlib-1.5.3-py3.6-linux-x86_64.egg/matplotlib/__init__.py:1357: UserWarning:  This call to matplotlib.use() has no effect\n",
      "because the backend has already been chosen;\n",
      "matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n",
      "or matplotlib.backends is imported for the first time.\n",
      "\n",
      "  warnings.warn(_use_error_msg)\n",
      "Using Theano backend.\n",
      "/home/rbharath/anaconda3/envs/deepchem/lib/python3.6/site-packages/matplotlib-1.5.3-py3.6-linux-x86_64.egg/matplotlib/__init__.py:1357: UserWarning:  This call to matplotlib.use() has no effect\n",
      "because the backend has already been chosen;\n",
      "matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n",
      "or matplotlib.backends is imported for the first time.\n",
      "\n",
      "  warnings.warn(_use_error_msg)\n"
     ]
    }
   ],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "from dragonn.tutorial_utils import *\n",
    "from tutorial_utils import *\n",
    "%matplotlib inline"
   ]
  },
+5 −6
Original line number Diff line number Diff line
@@ -12,12 +12,17 @@ from dragonn.metrics import ClassificationResult
from sklearn.svm import SVC as scikit_SVC
from sklearn.tree import DecisionTreeClassifier as scikit_DecisionTree
from sklearn.ensemble import RandomForestClassifier
from keras.models import model_from_json
from keras.models import Sequential
from keras.layers.core import (Activation, Dense, Dropout, Flatten, Permute,
                                Reshape, TimeDistributedDense)
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.recurrent import GRU
from keras.regularizers import l1
from keras.layers.core import (Activation, Dense, Flatten,
                                TimeDistributedDense)
from keras.layers.recurrent import GRU
from keras.callbacks import EarlyStopping


class Model(object):
@@ -327,7 +332,6 @@ class SequenceDNN(Model):

  @staticmethod
  def load(arch_fname, weights_fname=None):
    from keras.models import model_from_json
    model_json_string = open(arch_fname).read()
    sequence_dnn = SequenceDNN(keras_model=model_from_json(model_json_string))
    if weights_fname is not None:
@@ -338,10 +342,6 @@ class SequenceDNN(Model):
class MotifScoreRNN(Model):

  def __init__(self, input_shape, gru_size=10, tdd_size=4):
    from keras.models import Sequential
    from keras.layers.core import (Activation, Dense, Flatten,
                                   TimeDistributedDense)
    from keras.layers.recurrent import GRU
    self.model = Sequential()
    self.model.add(
        GRU(gru_size, return_sequences=True, input_shape=input_shape))
@@ -354,7 +354,6 @@ class MotifScoreRNN(Model):
    self.model.compile(optimizer='adam', loss='binary_crossentropy')

  def train(self, X, y, validation_data):
    from keras.callbacks import EarlyStopping
    print('Training model...')
    multitask = y.shape[1] > 1
    if not multitask: