Commit 0832fa80 authored by Peter Eastman's avatar Peter Eastman
Browse files

I do not like yapf

parent 331239d4
Loading
Loading
Loading
Loading
+13 −11
Original line number Diff line number Diff line
@@ -15,12 +15,14 @@ from deepchem.models import Model
from deepchem.data import DiskDataset
from deepchem.trans import undo_transforms


class SingletaskToMultitask(Model):
  """
  Convenience class to let singletask models be fit on multitask data.

  Warning: This current implementation is only functional for sklearn models. 
  """

  def __init__(self, tasks, model_builder, model_dir=None, verbose=True):
    super().__init__(self, model_dir=model_dir, verbose=verbose)
    self.tasks = tasks
@@ -45,8 +47,8 @@ class SingletaskToMultitask(Model):
      task_data_dirs.append(task_data_dir)
    task_datasets = self._to_singletask(dataset, task_data_dirs)
    for task, task_dataset in zip(self.tasks, task_datasets):
      log("Dataset for task %s has shape %s"
          % (task, str(task_dataset.get_shape())), self.verbose)
      log("Dataset for task %s has shape %s" %
          (task, str(task_dataset.get_shape())), self.verbose)
    return task_datasets

  @staticmethod
@@ -55,8 +57,10 @@ class SingletaskToMultitask(Model):
    tasks = dataset.get_task_names()
    assert len(tasks) == len(task_dirs)
    log("Splitting multitask dataset into singletask datasets", dataset.verbose)
    task_datasets = [DiskDataset.create_dataset([], task_dirs[task_num], [task])
                    for (task_num, task) in enumerate(tasks)]
    task_datasets = [
        DiskDataset.create_dataset([], task_dirs[task_num], [task])
        for (task_num, task) in enumerate(tasks)
    ]
    #task_metadata_rows = {task: [] for task in tasks}
    for shard_num, (X, y, w, ids) in enumerate(dataset.itershards()):
      log("Processing shard %d" % shard_num, dataset.verbose)
@@ -78,7 +82,6 @@ class SingletaskToMultitask(Model):

    return task_datasets


  def fit(self, dataset, **kwargs):
    """
    Updates all singletask models with new information.
@@ -91,8 +94,7 @@ class SingletaskToMultitask(Model):
    task_datasets = self._create_task_datasets(dataset)
    for ind, task in enumerate(self.tasks):
      log("Fitting model for task %s" % task, self.verbose)
      task_model = self.model_builder(
          self.task_model_dirs[task])
      task_model = self.model_builder(self.task_model_dirs[task])
      task_model.fit(task_datasets[ind], **kwargs)
      task_model.save()

@@ -150,8 +152,8 @@ class SingletaskToMultitask(Model):
      task_model = self.model_builder(self.task_model_dirs[task])
      task_model.reload()

      y_pred[:, ind] = np.squeeze(task_model.predict_proba(
          dataset, transformers, n_classes))
      y_pred[:, ind] = np.squeeze(
          task_model.predict_proba(dataset, transformers, n_classes))
    return y_pred

  def save(self):