Commit eef7bc2d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

yapf again

parent cff6617e
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -290,9 +290,9 @@ class SupportGraphClassifier(Model):
  def predict_on_batch(self, support, test_batch):
    """Make predictions on batch of data."""
    n_samples = len(test_batch)
    padded_test_batch = NumpyDataset(*pad_batch(self.test_batch_size,
                                                test_batch.X, test_batch.y,
                                                test_batch.w, test_batch.ids))
    X, y, w, ids = pad_batch(self.test_batch_size, test_batch.X, test_batch.y,
                             test_batch.w, test_batch.ids)
    padded_test_batch = NumpyDataset(X, y, w, ids)
    feed_dict = self.construct_feed_dict(padded_test_batch, support)
    # Get scores
    pred, scores = self.sess.run(
@@ -305,9 +305,9 @@ class SupportGraphClassifier(Model):
  def predict_proba_on_batch(self, support, test_batch):
    """Make predictions on batch of data."""
    n_samples = len(test_batch)
    padded_test_batch = NumpyDataset(*pad_batch(self.test_batch_size,
                                                test_batch.X, test_batch.y,
                                                test_batch.w, test_batch.ids))
    X, y, w, ids = pad_batch(self.test_batch_size, test_batch.X, test_batch.y,
                             test_batch.w, test_batch.ids)
    padded_test_batch = NumpyDataset(X, y, w, ids)
    feed_dict = self.construct_feed_dict(padded_test_batch, support)
    # Get scores
    pred, scores = self.sess.run(