Commit 7010d1a8 authored by Joseph Gomes's avatar Joseph Gomes
Browse files

Update evaluate for pad_batch

parent 7442cceb
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -66,7 +66,7 @@ class Evaluator(object):
        csvwriter.writerow([mol_id] + list(y_pred))

  def compute_model_performance(self, metrics, csv_out=None, stats_out=None,
                                threshold=None):
                                threshold=None, pad_batch=False):
    """
    Computes statistics of model on test data and saves results to csv.
    """
@@ -79,11 +79,11 @@ class Evaluator(object):
    else:
      mode = metrics[0].mode
    if mode == "classification":
      y_pred = self.model.predict_proba(self.dataset, self.output_transformers)
      y_pred = self.model.predict_proba(self.dataset, self.output_transformers, pad_batch=pad_batch)
      y_pred_print = self.model.predict(
          self.dataset, self.output_transformers).astype(int)
    else:
      y_pred = self.model.predict(self.dataset, self.output_transformers)
      y_pred = self.model.predict(self.dataset, self.output_transformers, pad_batch=pad_batch)
      y_pred_print = y_pred
    multitask_scores = {}