Commit 407db0e1 authored by peastman

Refactored molnet loader

parent 47006c5b
+2 −0
@@ -37,6 +37,8 @@ from deepchem.molnet.load_function.material_datasets.load_perovskite import load
 from deepchem.molnet.load_function.material_datasets.load_mp_formation_energy import load_mp_formation_energy
 from deepchem.molnet.load_function.material_datasets.load_mp_metallicity import load_mp_metallicity
 
+from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, _MolnetLoader
+
 from deepchem.molnet.dnasim import simulate_motif_density_localization
 from deepchem.molnet.dnasim import simulate_motif_counting
 from deepchem.molnet.dnasim import simple_motif_embedding
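Note: judging from the surrounding load_* imports, this hunk appears to live in deepchem/molnet/__init__.py, which is what makes the shortcut tables reachable as dc.molnet.featurizers and dc.molnet.splitters (the names the updated docstrings below refer to). A minimal sketch of how the string shortcuts would resolve, assuming that placement:

import deepchem as dc

# 'ecfp' and 'scaffold' map to entries defined in molnet_loader.py (see the new file below)
featurizer = dc.molnet.featurizers['ecfp']    # CircularFingerprint(size=1024)
splitter = dc.molnet.splitters['scaffold']    # ScaffoldSplitter()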
+0 −17
@@ -16,23 +16,6 @@ from deepchem.splits.splitters import Splitter
 
 logger = logging.getLogger(__name__)
 
-featurizers = {
-    'ecfp': dc.feat.CircularFingerprint(size=1024),
-    'graphconv': dc.feat.ConvMolFeaturizer(),
-    'weave': dc.feat.WeaveFeaturizer(),
-    'raw': dc.feat.RawFeaturizer(),
-    'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
-}
-
-splitters = {
-    'index': dc.splits.IndexSplitter(),
-    'random': dc.splits.RandomSplitter(),
-    'scaffold': dc.splits.ScaffoldSplitter(),
-    'butina': dc.splits.ButinaSplitter(),
-    'task': dc.splits.TaskSplitter(),
-    'stratified': dc.splits.RandomStratifiedSplitter()
-}
-
 
 def get_defaults(module_name: str = None) -> Dict[str, Any]:
   """Get featurizers, transformers, and splitters.
+34 −79
@@ -4,13 +4,32 @@ Delaney dataset loader.
 import os
 import logging
 import deepchem as dc
-from deepchem.data import Dataset, DiskDataset
+from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
+from deepchem.data import Dataset
 from typing import List, Optional, Tuple, Union
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_DIR = dc.utils.data_utils.get_data_dir()
 DELANEY_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv"
+DELANEY_TASKS = ['measured log solubility in mols per litre']
+
+
+class _DelaneyLoader(_MolnetLoader):
+
+  def create_dataset(self) -> Dataset:
+    logger.info("About to featurize Delaney dataset.")
+    dataset_file = os.path.join(self.data_dir, "delaney-processed.csv")
+    if not os.path.exists(dataset_file):
+      dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=self.data_dir)
+    loader = dc.data.CSVLoader(
+        tasks=DELANEY_TASKS, feature_field="smiles", featurizer=self.featurizer)
+    return loader.create_dataset(dataset_file, shard_size=8192)
+
+  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
+    return [
+        dc.trans.NormalizationTransformer(
+            transform_y=True, dataset=dataset, move_mean=self.args['move_mean'])
+    ]
 
 
 def load_delaney(
@@ -22,7 +41,7 @@ def load_delaney(
     save_dir: Optional[str] = None,
     **kwargs
 ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
-  """Load delaney dataset
+  """Load Delaney dataset
 
   The Delaney (ESOL) dataset a regression dataset containing structures and
   water solubility data for 1128 compounds. The dataset is widely used to
@@ -42,11 +61,11 @@ def load_delaney(
   ----------
   featurizer: Featurizer or str
     the featurizer to use for processing the data.  Alternatively you can pass
-    one of the names from dc.molnet.defaults.featurizers as a shortcut.
+    one of the names from dc.molnet.featurizers as a shortcut.
   splitter: Splitter or str
     the splitter to use for splitting the data into training, validation, and
     test sets.  Alternatively you can pass one of the names from
-    dc.molnet.defaults.splitters as a shortcut.  If this is None, all the data
+    dc.molnet.splitters as a shortcut.  If this is None, all the data
     will be included in a single dataset.
   reload: bool
     if True, the first call for a particular featurizer and splitter will cache
@@ -64,76 +83,12 @@ def load_delaney(
      molecular structure." Journal of chemical information and computer
      sciences 44.3 (2004): 1000-1005.
   """
-  if 'split' in kwargs:
-    splitter = kwargs['split']
-    logger.warning("'split' is deprecated.  Use 'splitter' instead.")
-  if isinstance(featurizer, str):
-    featurizer = dc.molnet.defaults.featurizers[featurizer.lower()]
-  if isinstance(splitter, str):
-    splitter = dc.molnet.defaults.splitters[splitter.lower()]
-  if data_dir is None:
-    data_dir = DEFAULT_DIR
-  if save_dir is None:
-    save_dir = DEFAULT_DIR
-  tasks = ['measured log solubility in mols per litre']
-
-  # Try to reload cached datasets.
-
-  if reload:
-    featurizer_name = str(featurizer)
-    splitter_name = str(splitter)
-    if not move_mean:
-      featurizer_name = featurizer_name + "_mean_unmoved"
-    save_folder = os.path.join(save_dir, "delaney-featurized", featurizer_name,
-                               splitter_name)
-    if splitter is None:
-      if os.path.exists(save_folder):
-        transformers = dc.utils.data_utils.load_transformers(save_folder)
-        return tasks, (DiskDataset(save_folder),), transformers
-    else:
-      loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
-          save_folder)
-      if all_dataset is not None:
-        return tasks, all_dataset, transformers
-
-  # Featurize Delaney dataset
-
-  logger.info("About to featurize Delaney dataset.")
-  dataset_file = os.path.join(data_dir, "delaney-processed.csv")
-  if not os.path.exists(dataset_file):
-    dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=data_dir)
-  loader = dc.data.CSVLoader(
-      tasks=tasks, feature_field="smiles", featurizer=featurizer)
-  dataset = loader.create_dataset(dataset_file, shard_size=8192)
-
-  # Split and transform the dataset.
-
-  if splitter is None:
-    transformer_dataset: Dataset = dataset
-  else:
-    logger.info("About to split dataset with {} splitter.".format(
-        splitter.__class__.__name__))
-    train, valid, test = splitter.train_valid_test_split(dataset)
-    transformer_dataset = train
-  transformers = [
-      dc.trans.NormalizationTransformer(
-          transform_y=True, dataset=transformer_dataset, move_mean=move_mean)
-  ]
-  logger.info("About to transform data.")
-  if splitter is None:
-    for transformer in transformers:
-      dataset = transformer.transform(dataset)
-    if reload and isinstance(dataset, DiskDataset):
-      dataset.move(save_folder)
-      dc.utils.data_utils.save_transformers(save_folder, transformers)
-    return tasks, (dataset,), transformers
-
-  for transformer in transformers:
-    train = transformer.transform(train)
-    valid = transformer.transform(valid)
-    test = transformer.transform(test)
-  if reload and isinstance(train, DiskDataset) and isinstance(
-      valid, DiskDataset) and isinstance(test, DiskDataset):
-    dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
-                                             transformers)
-  return tasks, (train, valid, test), transformers
+  loader = _DelaneyLoader(
+      featurizer, splitter, data_dir, save_dir, move_mean=move_mean, **kwargs)
+  featurizer_name = str(loader.featurizer)
+  splitter_name = 'None' if loader.splitter is None else str(loader.splitter)
+  if not move_mean:
+    featurizer_name = featurizer_name + "_mean_unmoved"
+  save_folder = os.path.join(loader.save_dir, "delaney-featurized",
+                             featurizer_name, splitter_name)
+  return loader.load_dataset(DELANEY_TASKS, save_folder, reload)
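With this refactor, load_delaney only builds a _DelaneyLoader and delegates the caching, splitting, and transforming logic above to _MolnetLoader.load_dataset (defined in the new file below). A usage sketch of the public entry point, following the updated docstring; the three-way split shown here assumes a splitter is given:

import deepchem as dc

# String shortcuts are resolved through the featurizers/splitters dicts.
tasks, (train, valid, test), transformers = dc.molnet.load_delaney(
    featurizer='ecfp', splitter='scaffold', reload=True)

# Passing splitter=None returns the full dataset as a single element instead.
tasks, (dataset,), transformers = dc.molnet.load_delaney(
    featurizer='ecfp', splitter=None)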
+139 −0
"""
Common code for loading MoleculeNet datasets.
"""
import os
import logging
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

featurizers = {
    'ecfp': dc.feat.CircularFingerprint(size=1024),
    'graphconv': dc.feat.ConvMolFeaturizer(),
    'weave': dc.feat.WeaveFeaturizer(),
    'raw': dc.feat.RawFeaturizer(),
    'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
}

splitters = {
    'index': dc.splits.IndexSplitter(),
    'random': dc.splits.RandomSplitter(),
    'scaffold': dc.splits.ScaffoldSplitter(),
    'butina': dc.splits.ButinaSplitter(),
    'task': dc.splits.TaskSplitter(),
    'stratified': dc.splits.RandomStratifiedSplitter()
}


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.
  It is an abstract class.  Subclasses implement loading of particular datasets.
  """

  def __init__(self, featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               data_dir: Optional[str], save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
    ----------
    featurizer: Featurizer or str
      the featurizer to use for processing the data.  Alternatively you can pass
      one of the names from dc.molnet.featurizers as a shortcut.
    splitter: Splitter or str
      the splitter to use for splitting the data into training, validation, and
      test sets.  Alternatively you can pass one of the names from
      dc.molnet.splitters as a shortcut.  If this is None, all the data
      will be included in a single dataset.
    data_dir: str
      a directory to save the raw data in
    save_dir: str
      a directory to save the dataset in
    """
    if 'split' in kwargs:
      splitter = kwargs['split']
      logger.warning("'split' is deprecated.  Use 'splitter' instead.")
    if isinstance(featurizer, str):
      featurizer = featurizers[featurizer.lower()]
    if isinstance(splitter, str):
      splitter = splitters[splitter.lower()]
    if data_dir is None:
      data_dir = dc.utils.data_utils.get_data_dir()
    if save_dir is None:
      save_dir = dc.utils.data_utils.get_data_dir()
    self.featurizer = featurizer
    self.splitter = splitter
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.args = kwargs

  def load_dataset(
      self, tasks: List[str], save_folder: str, reload: bool
  ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load the dataset.

    Parameters
    ----------
    tasks: List[str]
      the names of the tasks in this dataset
    save_folder: str
      the directory in which the dataset should be saved
    reload: bool
      if True, the first call for a particular featurizer and splitter will cache
      the datasets to disk, and subsequent calls will reload the cached datasets.
    """
    # Try to reload cached datasets.

    if reload:
      if self.splitter is None:
        if os.path.exists(save_folder):
          transformers = dc.utils.data_utils.load_transformers(save_folder)
          return tasks, (DiskDataset(save_folder),), transformers
      else:
        loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
            save_folder)
        if all_dataset is not None:
          return tasks, all_dataset, transformers

    # Create the dataset

    dataset = self.create_dataset()

    # Split and transform the dataset.

    if self.splitter is None:
      transformer_dataset: Dataset = dataset
    else:
      logger.info("About to split dataset with {} splitter.".format(
          self.splitter.__class__.__name__))
      train, valid, test = self.splitter.train_valid_test_split(dataset)
      transformer_dataset = train
    transformers = self.get_transformers(transformer_dataset)
    logger.info("About to transform data.")
    if self.splitter is None:
      for transformer in transformers:
        dataset = transformer.transform(dataset)
      if reload and isinstance(dataset, DiskDataset):
        dataset.move(save_folder)
        dc.utils.data_utils.save_transformers(save_folder, transformers)
      return tasks, (dataset,), transformers

    for transformer in transformers:
      train = transformer.transform(train)
      valid = transformer.transform(valid)
      test = transformer.transform(test)
    if reload and isinstance(train, DiskDataset) and isinstance(
        valid, DiskDataset) and isinstance(test, DiskDataset):
      dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
                                               transformers)
    return tasks, (train, valid, test), transformers

  def create_dataset(self) -> Dataset:
    """Subclasses must implement this to load the dataset."""
    raise NotImplementedError()

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    """Subclasses must implement this to create the transformers for the dataset."""
    raise NotImplementedError()
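
The base class reduces a MoleculeNet loader to two hooks, create_dataset and get_transformers, plus a thin load_* wrapper that picks the cache folder. A sketch of how another dataset could be ported onto this pattern, modeled on _DelaneyLoader above; _ExampleLoader, EXAMPLE_URL, EXAMPLE_TASKS, and example.csv are hypothetical names, not part of this commit:

import os
import deepchem as dc
from deepchem.data import Dataset
from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
from typing import List

EXAMPLE_URL = "https://example.com/datasets/example.csv"  # placeholder URL
EXAMPLE_TASKS = ['some measured property']  # placeholder task list


class _ExampleLoader(_MolnetLoader):

  def create_dataset(self) -> Dataset:
    # Download the raw CSV on first use, then featurize it with the
    # featurizer chosen by the caller (self.featurizer).
    dataset_file = os.path.join(self.data_dir, "example.csv")
    if not os.path.exists(dataset_file):
      dc.utils.data_utils.download_url(url=EXAMPLE_URL, dest_dir=self.data_dir)
    loader = dc.data.CSVLoader(
        tasks=EXAMPLE_TASKS, feature_field="smiles", featurizer=self.featurizer)
    return loader.create_dataset(dataset_file, shard_size=8192)

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    # Normalize the regression targets, as _DelaneyLoader does.
    return [dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]


def load_example(featurizer='ecfp', splitter='scaffold', reload=True,
                 data_dir=None, save_dir=None, **kwargs):
  loader = _ExampleLoader(featurizer, splitter, data_dir, save_dir, **kwargs)
  save_folder = os.path.join(loader.save_dir, "example-featurized",
                             str(loader.featurizer), str(loader.splitter))
  return loader.load_dataset(EXAMPLE_TASKS, save_folder, reload)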