Commit 18945093 authored by peastman's avatar peastman
Browse files

Changed how molnet handles Transformers

parent 5c55f230
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ from deepchem.molnet.load_function.material_datasets.load_perovskite import load
from deepchem.molnet.load_function.material_datasets.load_mp_formation_energy import load_mp_formation_energy
from deepchem.molnet.load_function.material_datasets.load_mp_metallicity import load_mp_metallicity

from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, _MolnetLoader
from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, transformers, TransformerGenerator, _MolnetLoader

from deepchem.molnet.dnasim import simulate_motif_density_localization
from deepchem.molnet.dnasim import simulate_motif_counting
+9 −19
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Delaney dataset loader.
import os
import logging
import deepchem as dc
from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from deepchem.data import Dataset
from typing import List, Optional, Tuple, Union

@@ -25,18 +25,12 @@ class _DelaneyLoader(_MolnetLoader):
        tasks=DELANEY_TASKS, feature_field="smiles", featurizer=self.featurizer)
    return loader.create_dataset(dataset_file, shard_size=8192)

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    return [
        dc.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset, move_mean=self.args['move_mean'])
    ]


def load_delaney(
    featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
    splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
    transformers: List[Union[TransformerGenerator, str]] = ['normalization'],
    reload: bool = True,
    move_mean: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs
@@ -67,11 +61,13 @@ def load_delaney(
    test sets.  Alternatively you can pass one of the names from
    dc.molnet.splitters as a shortcut.  If this is None, all the data
    will be included in a single dataset.
  transformers: list of TransformerGenerators or strings
    the Transformers to apply to the data.  Each one is specified by a
    TransformerGenerator or, as a shortcut, one of the names from
    dc.molnet.transformers.
  reload: bool
    if True, the first call for a particular featurizer and splitter will cache
    the datasets to disk, and subsequent calls will reload the cached datasets.
  move_mean: bool
    if True, all the data is shifted so the training set has a mean of zero.
  data_dir: str
    a directory to save the raw data in
  save_dir: str
@@ -83,12 +79,6 @@ def load_delaney(
     molecular structure." Journal of chemical information and computer
     sciences 44.3 (2004): 1000-1005.
  """
  loader = _DelaneyLoader(
      featurizer, splitter, data_dir, save_dir, move_mean=move_mean, **kwargs)
  featurizer_name = str(loader.featurizer)
  splitter_name = 'None' if loader.splitter is None else str(loader.splitter)
  if not move_mean:
    featurizer_name = featurizer_name + "_mean_unmoved"
  save_folder = os.path.join(loader.save_dir, "delaney-featurized",
                             featurizer_name, splitter_name)
  return loader.load_dataset(DELANEY_TASKS, save_folder, reload)
  loader = _DelaneyLoader(featurizer, splitter, transformers, data_dir,
                          save_dir, **kwargs)
  return loader.load_dataset('delaney', DELANEY_TASKS, reload)
+77 −9
Original line number Diff line number Diff line
@@ -5,10 +5,47 @@ import os
import logging
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Type, Union

logger = logging.getLogger(__name__)


class TransformerGenerator(object):
  """Deferred factory for Transformers.

  Molnet loader functions cannot take ready-made Transformers, because many
  Transformer constructors need the Dataset they will be applied to.  A
  TransformerGenerator instead remembers a Transformer class together with
  its constructor arguments, and builds the Transformer later, once the
  Dataset has been loaded.
  """

  def __init__(self, transformer_class: Type[dc.trans.Transformer], **kwargs):
    """Remember which Transformer to build and how to configure it.

    Parameters
    ----------
    transformer_class: Type[Transformer]
      the Transformer class that create_transformer() will instantiate
    kwargs:
      extra keyword arguments, forwarded unchanged to the Transformer's
      constructor
    """
    self.transformer_class = transformer_class
    self.kwargs = kwargs

  def create_transformer(self, dataset: Dataset) -> dc.trans.Transformer:
    """Instantiate the Transformer, passing `dataset` as its `dataset` argument."""
    build = self.transformer_class
    return build(dataset=dataset, **self.kwargs)

  def get_directory_name(self) -> str:
    """Return a directory-name fragment identifying this generator.

    The name is the Transformer class name followed by `_<key>_<value>` for
    each stored keyword argument, in the order the arguments were given.
    Arguments whose value is a list are left out of the name.
    """
    pieces = [self.transformer_class.__name__]
    pieces.extend(
        key + '_' + str(value)
        for key, value in self.kwargs.items()
        if not isinstance(value, list))
    return '_'.join(pieces)


featurizers = {
    'ecfp': dc.feat.CircularFingerprint(size=1024),
    'graphconv': dc.feat.ConvMolFeaturizer(),
@@ -26,6 +63,19 @@ splitters = {
    'stratified': dc.splits.RandomStratifiedSplitter()
}

# Shortcut names for commonly used TransformerGenerators.  'balancing' is
# built with no extra arguments; every other entry is configured with
# transform_y=True.
transformers = {
    'balancing': TransformerGenerator(dc.trans.BalancingTransformer),
    **{
        shortcut: TransformerGenerator(transformer_class, transform_y=True)
        for shortcut, transformer_class in (
            ('normalization', dc.trans.NormalizationTransformer),
            ('minmax', dc.trans.MinMaxTransformer),
            ('clipping', dc.trans.ClippingTransformer),
            ('log', dc.trans.LogTransformer),
        )
    },
}


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.
@@ -34,6 +84,7 @@ class _MolnetLoader(object):

  def __init__(self, featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               data_dir: Optional[str], save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

@@ -47,6 +98,10 @@ class _MolnetLoader(object):
      test sets.  Alternatively you can pass one of the names from
      dc.molnet.splitters as a shortcut.  If this is None, all the data
      will be included in a single dataset.
    transformer_generators: list of TransformerGenerators or strings
      the Transformers to apply to the data.  Each one is specified by a
      TransformerGenerator or, as a shortcut, one of the names from
      dc.molnet.transformers.
    data_dir: str
      a directory to save the raw data in
    save_dir: str
@@ -65,25 +120,40 @@ class _MolnetLoader(object):
      save_dir = dc.utils.data_utils.get_data_dir()
    self.featurizer = featurizer
    self.splitter = splitter
    self.transformers = [
        transformers[t.lower()] if isinstance(t, str) else t
        for t in transformer_generators
    ]
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.args = kwargs

  def load_dataset(
      self, tasks: List[str], save_folder: str, reload: bool
      self, name: str, tasks: List[str], reload: bool
  ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load the dataset.

    Parameters
    ----------
    name: str
      the name of the dataset, used to identify the directory on disk
    tasks: List[str]
      the names of the tasks in this dataset
    save_folder: str
      the directory in which the dataset should be saved
    reload: bool
      if True, the first call for a particular featurizer and splitter will cache
      the datasets to disk, and subsequent calls will reload the cached datasets.
    """
    # Build the path to the dataset on disk.

    featurizer_name = str(self.featurizer)
    splitter_name = 'None' if self.splitter is None else str(self.splitter)
    save_folder = os.path.join(self.save_dir, name + "-featurized",
                               featurizer_name, splitter_name)
    if len(self.transformers) > 0:
      transformer_name = '_'.join(
          t.get_directory_name() for t in self.transformers)
      save_folder = os.path.join(save_folder, transformer_name)

    # Try to reload cached datasets.

    if reload:
@@ -110,7 +180,9 @@ class _MolnetLoader(object):
          self.splitter.__class__.__name__))
      train, valid, test = self.splitter.train_valid_test_split(dataset)
      transformer_dataset = train
    transformers = self.get_transformers(transformer_dataset)
    transformers = [
        t.create_transformer(transformer_dataset) for t in self.transformers
    ]
    logger.info("About to transform data.")
    if self.splitter is None:
      for transformer in transformers:
@@ -133,7 +205,3 @@ class _MolnetLoader(object):
  def create_dataset(self) -> Dataset:
    """Load and return the Dataset; must be overridden by subclasses.

    This base implementation does no work and always raises
    NotImplementedError.
    """
    raise NotImplementedError

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    """Subclasses must implement this to create the transformers for the dataset."""
    raise NotImplementedError()