Unverified Commit 31c9b6bf authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2213 from peastman/molnet

[WIP] Updated API for MoleculeNet loader functions
parents 19eeac11 9e6155ff
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -37,6 +37,8 @@ from deepchem.molnet.load_function.material_datasets.load_perovskite import load
from deepchem.molnet.load_function.material_datasets.load_mp_formation_energy import load_mp_formation_energy
from deepchem.molnet.load_function.material_datasets.load_mp_metallicity import load_mp_metallicity

from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, transformers, TransformerGenerator, _MolnetLoader

from deepchem.molnet.dnasim import simulate_motif_density_localization
from deepchem.molnet.dnasim import simulate_motif_counting
from deepchem.molnet.dnasim import simple_motif_embedding
+1 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ import logging
import json
from typing import Dict, List, Any

import deepchem as dc
from deepchem.feat.base_classes import Featurizer
from deepchem.trans.transformers import Transformer
from deepchem.splits.splitters import Splitter
+54 −98
Original line number Diff line number Diff line
@@ -3,29 +3,46 @@ Delaney dataset loader.
"""
import os
import logging
import deepchem
import deepchem as dc
from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from deepchem.data import Dataset
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

DEFAULT_DIR = deepchem.utils.data_utils.get_data_dir()
DELANEY_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv"
DELANEY_TASKS = ['measured log solubility in mols per litre']


def load_delaney(featurizer='ECFP',
                 split='index',
                 reload=True,
                 move_mean=True,
                 data_dir=None,
                 save_dir=None,
                 **kwargs):
  """Load delaney dataset
class _DelaneyLoader(_MolnetLoader):
  """Loader for the Delaney (ESOL) solubility dataset."""

  def create_dataset(self) -> Dataset:
    """Featurize the Delaney CSV file, downloading it first if necessary.

    The raw file is looked for in `self.data_dir`; when it is not there it is
    fetched from `DELANEY_URL` into that directory.

    Returns
    -------
    Dataset
      the dataset produced by a CSVLoader that reads the "smiles" column and
      applies `self.featurizer`
    """
    logger.info("About to featurize Delaney dataset.")
    csv_path = os.path.join(self.data_dir, "delaney-processed.csv")
    already_downloaded = os.path.exists(csv_path)
    if not already_downloaded:
      dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=self.data_dir)
    csv_loader = dc.data.CSVLoader(
        tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer)
    return csv_loader.create_dataset(csv_path, shard_size=8192)


def load_delaney(
    featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
    splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
    transformers: List[Union[TransformerGenerator, str]] = ['normalization'],
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
  """Load Delaney dataset

  The Delaney (ESOL) dataset is a regression dataset containing structures and
  water solubility data for 1128 compounds. The dataset is widely used to
  validate machine learning models on estimating solubility directly from
  molecular structures (as encoded in SMILES strings).

  Random splitting is recommended for this dataset.
  Scaffold splitting is recommended for this dataset.

  The raw data csv file contains columns below:

@@ -34,95 +51,34 @@ def load_delaney(featurizer='ECFP',
  - "measured log solubility in mols per litre" - Log-scale water solubility
    of the compound, used as label

  Parameters
  ----------
  featurizer: Featurizer or str
    the featurizer to use for processing the data.  Alternatively you can pass
    one of the names from dc.molnet.featurizers as a shortcut.
  splitter: Splitter or str
    the splitter to use for splitting the data into training, validation, and
    test sets.  Alternatively you can pass one of the names from
    dc.molnet.splitters as a shortcut.  If this is None, all the data
    will be included in a single dataset.
  transformers: list of TransformerGenerators or strings
    the Transformers to apply to the data.  Each one is specified by a
    TransformerGenerator or, as a shortcut, one of the names from
    dc.molnet.transformers.
  reload: bool
    if True, the first call for a particular featurizer and splitter will cache
    the datasets to disk, and subsequent calls will reload the cached datasets.
  data_dir: str
    a directory to save the raw data in
  save_dir: str
    a directory to save the dataset in

  References
  ----------
  .. [1] Delaney, John S. "ESOL: estimating aqueous solubility directly from
     molecular structure." Journal of chemical information and computer
     sciences 44.3 (2004): 1000-1005.
  """
  # Featurize Delaney dataset
  logger.info("About to featurize Delaney dataset.")
  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
    save_dir = DEFAULT_DIR
  if reload:
    save_folder = os.path.join(save_dir, "delaney-featurized")
    if not move_mean:
      save_folder = os.path.join(save_folder, str(featurizer) + "_mean_unmoved")
    else:
      save_folder = os.path.join(save_folder, str(featurizer))

    if featurizer == "smiles2img":
      img_spec = kwargs.get("img_spec", "std")
      save_folder = os.path.join(save_folder, img_spec)
    save_folder = os.path.join(save_folder, str(split))

  dataset_file = os.path.join(data_dir, "delaney-processed.csv")

  if not os.path.exists(dataset_file):
    deepchem.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=data_dir)

  delaney_tasks = ['measured log solubility in mols per litre']
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.data_utils.load_dataset_from_disk(
        save_folder)
    if loaded:
      return delaney_tasks, all_dataset, transformers

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()
  elif featurizer == "smiles2img":
    img_spec = kwargs.get("img_spec", "std")
    img_size = kwargs.get("img_size", 80)
    res = kwargs.get("res", 0.5)
    featurizer = deepchem.feat.SmilesToImage(
        img_size=img_size, img_spec=img_spec, res=res)

  loader = deepchem.data.CSVLoader(
      tasks=delaney_tasks, feature_field="smiles", featurizer=featurizer)
  dataset = loader.create_dataset(dataset_file, shard_size=8192)

  if split is None:
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset, move_mean=move_mean)
    ]

    logger.info("Split is None, about to transform data")
    for transformer in transformers:
      dataset = transformer.transform(dataset)

    return delaney_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'stratified': deepchem.splits.SingletaskStratifiedSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
      deepchem.trans.NormalizationTransformer(
          transform_y=True, dataset=train, move_mean=move_mean)
  ]

  logger.info("About to transform data.")
  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.data_utils.save_dataset_to_disk(save_folder, train, valid,
                                                   test, transformers)
  return delaney_tasks, (train, valid, test), transformers
  loader = _DelaneyLoader(featurizer, splitter, transformers, DELANEY_TASKS,
                          data_dir, save_dir, **kwargs)
  return loader.load_dataset('delaney', reload)
+209 −0
Original line number Diff line number Diff line
"""
Common code for loading MoleculeNet datasets.
"""
import os
import logging
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from typing import List, Optional, Tuple, Type, Union

logger = logging.getLogger(__name__)


class TransformerGenerator(object):
  """Deferred factory for Transformers.

  Many Transformers need the Dataset they will be applied to as a constructor
  argument, so they cannot be built before the data has been loaded.  The
  molnet loaders therefore accept TransformerGenerator objects: each one
  remembers a Transformer class together with its keyword arguments and
  instantiates the Transformer once the Dataset is available.
  """

  def __init__(self, transformer_class: Type[dc.trans.Transformer], **kwargs):
    """Record what kind of Transformer to create later.

    Parameters
    ----------
    transformer_class: Type[Transformer]
      the class of Transformer to create
    kwargs:
      any additional arguments are passed to the Transformer's constructor
    """
    self.transformer_class = transformer_class
    self.kwargs = kwargs

  def create_transformer(self, dataset: Dataset) -> dc.trans.Transformer:
    """Instantiate the Transformer for `dataset`.

    The dataset is passed as the `dataset` keyword argument, followed by the
    keyword arguments given to this generator's constructor.
    """
    factory = self.transformer_class
    return factory(dataset=dataset, **self.kwargs)

  def get_directory_name(self) -> str:
    """Get a name for directories on disk describing this Transformer.

    The name is the Transformer class name followed by `_<key>_<value>` for
    every keyword argument, in the order they were given.  Arguments whose
    value is a list are left out of the name.
    """
    pieces = [self.transformer_class.__name__]
    pieces.extend(key + '_' + str(value)
                  for key, value in self.kwargs.items()
                  if not isinstance(value, list))
    return '_'.join(pieces)


# Shortcut names for featurizers.  _MolnetLoader lower-cases a featurizer
# given as a string and looks it up here, so every key is lower case.
featurizers = dict(
    ecfp=dc.feat.CircularFingerprint(size=1024),
    graphconv=dc.feat.ConvMolFeaturizer(),
    weave=dc.feat.WeaveFeaturizer(),
    raw=dc.feat.RawFeaturizer(),
    smiles2img=dc.feat.SmilesToImage(img_size=80, img_spec='std'),
)

# Shortcut names for splitters.  _MolnetLoader lower-cases a splitter given
# as a string and looks it up here, so every key is lower case.
splitters = dict(
    index=dc.splits.IndexSplitter(),
    random=dc.splits.RandomSplitter(),
    scaffold=dc.splits.ScaffoldSplitter(),
    butina=dc.splits.ButinaSplitter(),
    task=dc.splits.TaskSplitter(),
    stratified=dc.splits.RandomStratifiedSplitter(),
)

# Shortcut names for transformers.  Each entry is a TransformerGenerator
# rather than a Transformer, because the Transformer is only constructed once
# the Dataset it applies to is available.  _MolnetLoader lower-cases a
# transformer given as a string and looks it up here.
transformers = dict(
    balancing=TransformerGenerator(dc.trans.BalancingTransformer),
    normalization=TransformerGenerator(
        dc.trans.NormalizationTransformer, transform_y=True),
    minmax=TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True),
    clipping=TransformerGenerator(
        dc.trans.ClippingTransformer, transform_y=True),
    log=TransformerGenerator(dc.trans.LogTransformer, transform_y=True),
)


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.
  It is an abstract class.  Subclasses implement loading of particular datasets.
  """

  def __init__(self, featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               tasks: List[str], data_dir: Optional[str],
               save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
    ----------
    featurizer: Featurizer or str
      the featurizer to use for processing the data.  Alternatively you can pass
      one of the names from dc.molnet.featurizers as a shortcut.
    splitter: Splitter or str
      the splitter to use for splitting the data into training, validation, and
      test sets.  Alternatively you can pass one of the names from
      dc.molnet.splitters as a shortcut.  If this is None, all the data
      will be included in a single dataset.
    transformer_generators: list of TransformerGenerators or strings
      the Transformers to apply to the data.  Each one is specified by a
      TransformerGenerator or, as a shortcut, one of the names from
      dc.molnet.transformers.
    tasks: List[str]
      the names of the tasks in the dataset
    data_dir: str
      a directory to save the raw data in
    save_dir: str
      a directory to save the dataset in
    """
    if 'split' in kwargs:
      splitter = kwargs['split']
      logger.warning("'split' is deprecated.  Use 'splitter' instead.")
    if isinstance(featurizer, str):
      featurizer = featurizers[featurizer.lower()]
    if isinstance(splitter, str):
      splitter = splitters[splitter.lower()]
    if data_dir is None:
      data_dir = dc.utils.data_utils.get_data_dir()
    if save_dir is None:
      save_dir = dc.utils.data_utils.get_data_dir()
    self.featurizer = featurizer
    self.splitter = splitter
    self.transformers = [
        transformers[t.lower()] if isinstance(t, str) else t
        for t in transformer_generators
    ]
    self.tasks = list(tasks)
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.args = kwargs

  def load_dataset(
      self, name: str, reload: bool
  ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load the dataset.

    Parameters
    ----------
    name: str
      the name of the dataset, used to identify the directory on disk
    reload: bool
      if True, the first call for a particular featurizer and splitter will cache
      the datasets to disk, and subsequent calls will reload the cached datasets.
    """
    # Build the path to the dataset on disk.

    featurizer_name = str(self.featurizer)
    splitter_name = 'None' if self.splitter is None else str(self.splitter)
    save_folder = os.path.join(self.save_dir, name + "-featurized",
                               featurizer_name, splitter_name)
    if len(self.transformers) > 0:
      transformer_name = '_'.join(
          t.get_directory_name() for t in self.transformers)
      save_folder = os.path.join(save_folder, transformer_name)

    # Try to reload cached datasets.

    if reload:
      if self.splitter is None:
        if os.path.exists(save_folder):
          transformers = dc.utils.data_utils.load_transformers(save_folder)
          return self.tasks, (DiskDataset(save_folder),), transformers
      else:
        loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
            save_folder)
        if all_dataset is not None:
          return self.tasks, all_dataset, transformers

    # Create the dataset

    dataset = self.create_dataset()

    # Split and transform the dataset.

    if self.splitter is None:
      transformer_dataset: Dataset = dataset
    else:
      logger.info("About to split dataset with {} splitter.".format(
          self.splitter.__class__.__name__))
      train, valid, test = self.splitter.train_valid_test_split(dataset)
      transformer_dataset = train
    transformers = [
        t.create_transformer(transformer_dataset) for t in self.transformers
    ]
    logger.info("About to transform data.")
    if self.splitter is None:
      for transformer in transformers:
        dataset = transformer.transform(dataset)
      if reload and isinstance(dataset, DiskDataset):
        dataset.move(save_folder)
        dc.utils.data_utils.save_transformers(save_folder, transformers)
      return self.tasks, (dataset,), transformers

    for transformer in transformers:
      train = transformer.transform(train)
      valid = transformer.transform(valid)
      test = transformer.transform(test)
    if reload and isinstance(train, DiskDataset) and isinstance(
        valid, DiskDataset) and isinstance(test, DiskDataset):
      dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
                                               transformers)
    return self.tasks, (train, valid, test), transformers

  def create_dataset(self) -> Dataset:
    """Subclasses must implement this to load the dataset."""
    raise NotImplementedError()
+53 −19
Original line number Diff line number Diff line
"""
Contains an abstract base class that supports chemically aware data splits.
"""
import inspect
import os
import random
import tempfile
@@ -13,6 +14,7 @@ import pandas as pd

import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from deepchem.utils import get_print_threshold

logger = logging.getLogger(__name__)

@@ -103,8 +105,7 @@ class Splitter(object):
      train_ds_base = DiskDataset.merge(update_train_base_merge)
    return list(zip(train_datasets, cv_datasets))

  def train_valid_test_split(
      self,
  def train_valid_test_split(self,
                             dataset: Dataset,
                             train_dir: Optional[str] = None,
                             valid_dir: Optional[str] = None,
@@ -114,7 +115,7 @@ class Splitter(object):
                             frac_test: float = 0.1,
                             seed: Optional[int] = None,
                             log_every_n: int = 1000,
      **kwargs) -> Tuple[Dataset, Optional[Dataset], Dataset]:
                             **kwargs) -> Tuple[Dataset, Dataset, Dataset]:
    """ Splits self into train/validation/test sets.

    Returns Dataset objects for train, valid, test.
@@ -169,10 +170,7 @@ class Splitter(object):
    if test_dir is None:
      test_dir = tempfile.mkdtemp()
    train_dataset = dataset.select(train_inds, train_dir)
    if frac_valid != 0:
      valid_dataset: Optional[Dataset] = dataset.select(valid_inds, valid_dir)
    else:
      valid_dataset = None
    valid_dataset = dataset.select(valid_inds, valid_dir)
    test_dataset = dataset.select(test_inds, test_dir)
    if isinstance(train_dataset, DiskDataset):
      train_dataset.memory_cache_size = 40 * (1 << 20)  # 40 MB
@@ -274,7 +272,30 @@ class Splitter(object):
    >>> str(dc.splits.RandomSplitter())
    'RandomSplitter'
    """
    return self.__class__.__name__
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_num = len(args_names)
    args_default_values = [None for _ in range(args_num)]
    if args_spec.defaults is not None:
      defaults = list(args_spec.defaults)
      args_default_values[-len(defaults):] = defaults

    override_args_info = ''
    for arg_name, default in zip(args_names, args_default_values):
      if arg_name in self.__dict__:
        arg_value = self.__dict__[arg_name]
        # validation
        # skip list
        if isinstance(arg_value, list):
          continue
        if isinstance(arg_value, str):
          # skip path string
          if "\\/." in arg_value or "/" in arg_value or '.' in arg_value:
            continue
        # main logic
        if default != arg_value:
          override_args_info += '_' + arg_name + '_' + str(arg_value)
    return self.__class__.__name__ + override_args_info

  def __repr__(self) -> str:
    """Convert self to repr representation.
@@ -288,9 +309,22 @@ class Splitter(object):
    --------
    >>> import deepchem as dc
    >>> dc.splits.RandomSplitter()
    RandomSplitter
    RandomSplitter[]
    """
    return self.__str__()
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_info = ''
    for arg_name in args_names:
      value = self.__dict__[arg_name]
      # for str
      if isinstance(value, str):
        value = "'" + value + "'"
      # for list
      if isinstance(value, list):
        threshold = get_print_threshold()
        value = np.array2string(np.array(value), threshold=threshold)
      args_info += arg_name + '=' + str(value) + ', '
    return self.__class__.__name__ + '[' + args_info[:-2] + ']'


class RandomSplitter(Splitter):
Loading