Commit 3e3438cc authored by peastman's avatar peastman
Browse files

More fixes

parent d123a4d4
Loading
Loading
Loading
Loading
+32 −36
Original line number Diff line number Diff line
@@ -605,7 +605,6 @@ class GraphConvTensorGraph(TensorGraph):
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      with tf.Session() as sess:
      out_tensors = [x.out_tensor for x in self.outputs]
      results = []
      for feed_dict in generator:
@@ -614,7 +613,7 @@ class GraphConvTensorGraph(TensorGraph):
            for k, v in six.iteritems(feed_dict)
        }
        feed_dict[self._training_placeholder] = 1.0  ##
          result = np.array(sess.run(out_tensors, feed_dict=feed_dict))
        result = np.array(self.session.run(out_tensors, feed_dict=feed_dict))
        if len(result.shape) == 3:
          result = np.transpose(result, axes=[1, 0, 2])
        if len(transformers) > 0:
@@ -866,9 +865,6 @@ class MPNNTensorGraph(TensorGraph):
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      with tf.Session() as sess:
        saver = tf.train.Saver()
        self._initialize_weights(sess, saver)
      out_tensors = [x.out_tensor for x in self.outputs]
      results = []
      for feed_dict in generator:
@@ -879,7 +875,7 @@ class MPNNTensorGraph(TensorGraph):
            for k, v in six.iteritems(feed_dict)
        }
        feed_dict[self._training_placeholder] = 0.0
          result = np.array(sess.run(out_tensors, feed_dict=feed_dict))
        result = np.array(self.session.run(out_tensors, feed_dict=feed_dict))
        if len(result.shape) == 3:
          result = np.transpose(result, axes=[1, 0, 2])
        result = undo_transforms(result, transformers)