Commit a01e688d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Cleaning up

parent 8a015062
Loading
Loading
Loading
Loading
+0 −6
Original line number Diff line number Diff line
@@ -2694,7 +2694,6 @@ class DTNNEmbedding(tf.keras.layers.Layer):
          initializer=self.init,
          trainable=True)

    #init = initializers.get(self.init)
    self.embedding_list = init([self.periodic_table_length, self.n_embedding])
    self.built = True

@@ -2755,7 +2754,6 @@ class DTNNStep(tf.keras.layers.Layer):
          initializer=self.init,
          trainable=True)

    #init = initializers.get(self.init)
    self.W_cf = init([self.n_embedding, self.n_hidden])
    self.W_df = init([self.n_distance, self.n_hidden])
    self.W_fc = init([self.n_hidden, self.n_embedding])
@@ -2848,7 +2846,6 @@ class DTNNGather(tf.keras.layers.Layer):
          initializer=self.init,
          trainable=True)

    #init = initializers.get(self.init)
    prev_layer_size = self.n_embedding
    for i, layer_size in enumerate(self.layer_sizes):
      self.W_list.append(init([prev_layer_size, layer_size]))
@@ -3264,7 +3261,6 @@ class EdgeNetwork(tf.keras.layers.Layer):

    n_pair_features = self.n_pair_features
    n_hidden = self.n_hidden
    #init = initializers.get(self.init)
    self.W = init([n_pair_features, n_hidden * n_hidden])
    self.b = backend.zeros(shape=(n_hidden * n_hidden,))
    self.built = True
@@ -3302,7 +3298,6 @@ class GatedRecurrentUnit(tf.keras.layers.Layer):
          initializer=self.init,
          trainable=True)

    #init = initializers.get(self.init)
    self.Wz = init([n_hidden, n_hidden])
    self.Wr = init([n_hidden, n_hidden])
    self.Wh = init([n_hidden, n_hidden])
@@ -3365,7 +3360,6 @@ class SetGather(tf.keras.layers.Layer):
          initializer=self.init,
          trainable=True)

    #init = initializers.get(self.init)
    self.U = init((2 * self.n_hidden, 4 * self.n_hidden))
    self.b = tf.Variable(
        np.concatenate((np.zeros(self.n_hidden), np.ones(self.n_hidden),
+5 −10
Original line number Diff line number Diff line
@@ -773,6 +773,10 @@ def test_textCNN_classification_reload():
      model_dir=model_dir)
  reloaded_model.restore()

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [classification_metric])
  assert scores[classification_metric.name] > .8

  assert len(reloaded_model.model.get_weights()) == len(
      model.model.get_weights())
  for (reloaded, orig) in zip(reloaded_model.model.get_weights(),
@@ -785,18 +789,9 @@ def test_textCNN_classification_reload():
  predset = dc.data.NumpyDataset(Xpred, ids=predmols)
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)

  Xproc = reloaded_model.smiles_to_seq_batch(np.array(predmols))
  reloadout = reloaded_model.model(Xproc)
  origout = model.model(Xproc)

  assert len(model.model.layers) == len(reloaded_model.model.layers)

  assert np.all(origpred == reloadpred)

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [classification_metric])
  assert scores[classification_metric.name] > .8
  assert len(model.model.layers) == len(reloaded_model.model.layers)


def test_1d_cnn_regression_reload():