Commit 9e6155ff authored by peastman
Browse files

Attempt at fixing travis failures

parent 18945093
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ class _DelaneyLoader(_MolnetLoader):
    if not os.path.exists(dataset_file):
      dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=self.data_dir)
    loader = dc.data.CSVLoader(
        tasks=DELANEY_TASKS, feature_field="smiles", featurizer=self.featurizer)
        tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer)
    return loader.create_dataset(dataset_file, shard_size=8192)


@@ -79,6 +79,6 @@ def load_delaney(
     molecular structure." Journal of chemical information and computer
     sciences 44.3 (2004): 1000-1005.
  """
  loader = _DelaneyLoader(featurizer, splitter, transformers, data_dir,
                          save_dir, **kwargs)
  return loader.load_dataset('delaney', DELANEY_TASKS, reload)
  loader = _DelaneyLoader(featurizer, splitter, transformers, DELANEY_TASKS,
                          data_dir, save_dir, **kwargs)
  return loader.load_dataset('delaney', reload)
+10 −8
Original line number Diff line number Diff line
@@ -85,7 +85,8 @@ class _MolnetLoader(object):
  def __init__(self, featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               data_dir: Optional[str], save_dir: Optional[str], **kwargs):
               tasks: List[str], data_dir: Optional[str],
               save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
@@ -102,6 +103,8 @@ class _MolnetLoader(object):
      the Transformers to apply to the data.  Each one is specified by a
      TransformerGenerator or, as a shortcut, one of the names from
      dc.molnet.transformers.
    tasks: List[str]
      the names of the tasks in the dataset
    data_dir: str
      a directory to save the raw data in
    save_dir: str
@@ -124,12 +127,13 @@ class _MolnetLoader(object):
        transformers[t.lower()] if isinstance(t, str) else t
        for t in transformer_generators
    ]
    self.tasks = list(tasks)
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.args = kwargs

  def load_dataset(
      self, name: str, tasks: List[str], reload: bool
      self, name: str, reload: bool
  ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load the dataset.

@@ -137,8 +141,6 @@ class _MolnetLoader(object):
    ----------
    name: str
      the name of the dataset, used to identify the directory on disk
    tasks: List[str]
      the names of the tasks in this dataset
    reload: bool
      if True, the first call for a particular featurizer and splitter will cache
      the datasets to disk, and subsequent calls will reload the cached datasets.
@@ -160,12 +162,12 @@ class _MolnetLoader(object):
      if self.splitter is None:
        if os.path.exists(save_folder):
          transformers = dc.utils.data_utils.load_transformers(save_folder)
          return tasks, (DiskDataset(save_folder),), transformers
          return self.tasks, (DiskDataset(save_folder),), transformers
      else:
        loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
            save_folder)
        if all_dataset is not None:
          return tasks, all_dataset, transformers
          return self.tasks, all_dataset, transformers

    # Create the dataset

@@ -190,7 +192,7 @@ class _MolnetLoader(object):
      if reload and isinstance(dataset, DiskDataset):
        dataset.move(save_folder)
        dc.utils.data_utils.save_transformers(save_folder, transformers)
      return tasks, (dataset,), transformers
      return self.tasks, (dataset,), transformers

    for transformer in transformers:
      train = transformer.transform(train)
@@ -200,7 +202,7 @@ class _MolnetLoader(object):
        valid, DiskDataset) and isinstance(test, DiskDataset):
      dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
                                               transformers)
    return tasks, (train, valid, test), transformers
    return self.tasks, (train, valid, test), transformers

  def create_dataset(self) -> Dataset:
    """Subclasses must implement this to load the dataset."""