Commit 20f2022d authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Add support for model saving/loading

parent 6fbe8494
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -129,12 +129,20 @@ class ANIRegression(TensorGraph):
    self.build_graph()
    self.grad = None

  def save(self):
    self.grad = None # recompute grad on restore
    super(ANIRegression, self).save()

  def build_grad(self):
    self.grad = tf.gradients(self.outputs, self.atom_feats)

  def compute_grad(self, dataset, batch_size=1):
    with self._get_tf("Graph").as_default():
      if not self.built:
        self.build()
      if not self.grad:
        self.grad = tf.gradients(self.outputs, self.atom_feats)
        self.build_grad()

      feed_dict = dict()
      X = dataset.X
      flags = np.sign(np.array(X[:batch_size, :, 0]))
+26 −15
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ import deepchem as dc
import pyanitools as pya
import app

import dill

def convert_species_to_atomic_nums(s):
  PERIODIC_TABLE = {"H": 1, "C": 6, "N": 7, "O": 8}
  res = []
@@ -170,6 +172,13 @@ if __name__ == "__main__":
      dc.metrics.Metric(dc.metrics.pearson_r2_score, mode="regression")
  ]

  model_dir = "/tmp/ani.pkl"

  if os.path.exists(model_dir):
    print("Restoring existing model...")
    model = dc.models.ANIRegression.load_from_dir(model_dir=model_dir)
  else:
    print("Fitting new model...")
    model = dc.models.ANIRegression(
        1,
        max_atoms,
@@ -178,12 +187,15 @@ if __name__ == "__main__":
        batch_size=batch_size,
        learning_rate=0.001,
        use_queue=False,
      model_dir="/tmp/ani.pkl",
        model_dir=model_dir,
        mode="regression")

    # For production, set nb_epoch to 100+
    model.fit(train_dataset, nb_epoch=1, checkpoint_interval=100)

    print("Saving model...")
    model.save()

  print("Evaluating model")
  train_scores = model.evaluate(train_dataset, metric, transformers)
  valid_scores = model.evaluate(valid_dataset, metric, transformers)
@@ -212,8 +224,7 @@ if __name__ == "__main__":
  print("Gradient of a single test set element:")
  print(model.grad_one(coords, atomic_nums))

  # currently broken
  # model.save()


  app.webapp.model = model
  app.webapp.run(host='0.0.0.0', debug=False)