Commit 8f1917f5 authored by Bharath Ramsundar
Browse files

Cleanup and some test fixes

parent bb63dd70
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -686,7 +686,8 @@ class Dataset(object):
    """Computes and returns statistics of this dataset

    This function assumes that the first task of a dataset holds the energy for
-    an input system, and that the remaining tasks holds the gradient for the system.
+    an input system, and that the remaining tasks holds the gradient for the
+    system.

    TODO(rbharath, joegomes): It is unclear whether this should be a Dataset
    function. Might get refactored out.
+1 −1
Original line number Diff line number Diff line
@@ -156,7 +156,7 @@ class Model(object):
    y_pred = np.reshape(y_pred, (n_samples, n_tasks))
    # Special case to handle singletasks.
    if n_tasks == 1:
-      y_pred = np.squeeze(y_pred)
+      y_pred = np.reshape(y_pred, (n_samples,))
    return y_pred

  def predict_grad(self, dataset, transformers=[]):
+3 −3
Original line number Diff line number Diff line
@@ -153,7 +153,7 @@ class SingletaskToMultitask(Model):
    Concatenates results from all singletask models.
    """
    n_tasks = len(self.tasks)
-    n_samples = X.shape[0]
+    n_samples = len(dataset)
    y_pred = np.zeros((n_samples, n_tasks, n_classes))
    for ind, task in enumerate(self.tasks):
      if self.store_in_memory:
@@ -165,8 +165,8 @@ class SingletaskToMultitask(Model):
            verbosity=self.verbosity)
        task_model.reload()

-      y_pred[:, ind] = task_model.predict_proba_on_batch(
-          dataset, transformers, n_classes)
+      y_pred[:, ind] = np.squeeze(task_model.predict_proba(
+          dataset, transformers, n_classes))
    return y_pred

  def save(self):
+2 −0
Original line number Diff line number Diff line
@@ -56,6 +56,8 @@ class Evaluator(object):
      csvfile: Open file object.
    """
    mol_ids = self.dataset.get_ids()
+    n_tasks = len(self.task_names)
+    y_preds = np.reshape(y_preds, (len(y_preds), n_tasks))
    assert len(y_preds) == len(mol_ids)
    with open(csv_out, "wb") as csvfile:
      csvwriter = csv.writer(csvfile)