Commit bd52b89c authored by peastman's avatar peastman
Browse files

Updated API for load_delaney()

parent 9ce7a2a4
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
@@ -9,12 +9,29 @@ import logging
import json
from typing import Dict, List, Any

import deepchem as dc
from deepchem.feat.base_classes import Featurizer
from deepchem.trans.transformers import Transformer
from deepchem.splits.splitters import Splitter

logger = logging.getLogger(__name__)

# Ready-made featurizer instances keyed by short name.  Loaders that accept a
# featurizer name as a string look the instance up in this mapping.
featurizers = {
    'ECFP': dc.feat.CircularFingerprint(size=1024),  # 1024-sized fingerprint
    'GraphConv': dc.feat.ConvMolFeaturizer(),
    'Weave': dc.feat.WeaveFeaturizer(),
    'Raw': dc.feat.RawFeaturizer(),
}

# Ready-made splitter instances keyed by short name.  Loaders that accept a
# splitter name as a string look the instance up in this mapping.
splitters = {
    'index': dc.splits.IndexSplitter(),
    'random': dc.splits.RandomSplitter(),
    'scaffold': dc.splits.ScaffoldSplitter(),
    'butina': dc.splits.ButinaSplitter(),
    'task': dc.splits.TaskSplitter(),
    'stratified': dc.splits.RandomStratifiedSplitter(),
}


def get_defaults(module_name: str = None) -> Dict[str, Any]:
  """Get featurizers, transformers, and splitters.
+86 −75
Original line number Diff line number Diff line
@@ -3,21 +3,25 @@ Delaney dataset loader.
"""
import os
import logging
import deepchem
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

DEFAULT_DIR = deepchem.utils.data_utils.get_data_dir()
DEFAULT_DIR = dc.utils.data_utils.get_data_dir()
DELANEY_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv"


def load_delaney(featurizer='ECFP',
                 split='index',
                 reload=True,
                 move_mean=True,
                 data_dir=None,
                 save_dir=None,
                 **kwargs):
def load_delaney(
    featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
    splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
    reload: bool = True,
    move_mean: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
  """Load delaney dataset

  The Delaney (ESOL) dataset is a regression dataset containing structures and
@@ -25,7 +29,7 @@ def load_delaney(featurizer='ECFP',
  validate machine learning models on estimating solubility directly from
  molecular structures (as encoded in SMILES strings).

  Random splitting is recommended for this dataset.
  Scaffold splitting is recommended for this dataset.

  The raw data csv file contains columns below:

@@ -34,95 +38,102 @@ def load_delaney(featurizer='ECFP',
  - "measured log solubility in mols per litre" - Log-scale water solubility
    of the compound, used as label

  Parameters
  ----------
  featurizer: Featurizer or str
    the featurizer to use for processing the data.  Alternatively you can pass
    one of the names from dc.molnet.defaults.featurizers as a shortcut.
  splitter: Splitter or str
    the splitter to use for splitting the data into training, validation, and
    test sets.  Alternatively you can pass one of the names from
    dc.molnet.defaults.splitters as a shortcut.  If this is None, all the data
    will be included in a single dataset.
  reload: bool
    if True, the first call for a particular featurizer and splitter will cache
    the datasets to disk, and subsequent calls will reload the cached datasets.
  move_mean: bool
    if True, all the data is shifted so the training set has a mean of zero.
  data_dir: str
    a directory to save the raw data in
  save_dir: str
    a directory to save the dataset in

  References
  ----------
  .. [1] Delaney, John S. "ESOL: estimating aqueous solubility directly from
     molecular structure." Journal of chemical information and computer
     sciences 44.3 (2004): 1000-1005.
  """
  # Featurize Delaney dataset
  logger.info("About to featurize Delaney dataset.")
  if 'split' in kwargs:
    splitter = kwargs['split']
    logger.warning("'split' is deprecated.  Use 'splitter' instead.")
  if isinstance(featurizer, str):
    featurizer = dc.molnet.defaults.featurizers[featurizer]
  if isinstance(splitter, str):
    splitter = dc.molnet.defaults.splitters[splitter]
  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
    save_dir = DEFAULT_DIR
  tasks = ['measured log solubility in mols per litre']

  # Try to reload cached datasets.

  if reload:
    save_folder = os.path.join(save_dir, "delaney-featurized")
    featurizer_name = str(featurizer.__class__.__name__)
    splitter_name = str(splitter.__class__.__name__)
    if not move_mean:
      save_folder = os.path.join(save_folder, str(featurizer) + "_mean_unmoved")
      featurizer_name = featurizer_name + "_mean_unmoved"
    save_folder = os.path.join(save_dir, "delaney-featurized", featurizer_name,
                               splitter_name)
    if splitter is None:
      if os.path.exists(save_folder):
        transformers = dc.utils.data_utils.load_transformers(save_folder)
        return tasks, (DiskDataset(save_folder),), transformers
    else:
      save_folder = os.path.join(save_folder, str(featurizer))
      loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
          save_folder)
      if all_dataset is not None:
        return tasks, all_dataset, transformers

    if featurizer == "smiles2img":
      img_spec = kwargs.get("img_spec", "std")
      save_folder = os.path.join(save_folder, img_spec)
    save_folder = os.path.join(save_folder, str(split))
  # Featurize Delaney dataset

  logger.info("About to featurize Delaney dataset.")
  dataset_file = os.path.join(data_dir, "delaney-processed.csv")

  if not os.path.exists(dataset_file):
    deepchem.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=data_dir)

  delaney_tasks = ['measured log solubility in mols per litre']
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.data_utils.load_dataset_from_disk(
        save_folder)
    if loaded:
      return delaney_tasks, all_dataset, transformers

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()
  elif featurizer == "smiles2img":
    img_spec = kwargs.get("img_spec", "std")
    img_size = kwargs.get("img_size", 80)
    res = kwargs.get("res", 0.5)
    featurizer = deepchem.feat.SmilesToImage(
        img_size=img_size, img_spec=img_spec, res=res)

  loader = deepchem.data.CSVLoader(
      tasks=delaney_tasks, feature_field="smiles", featurizer=featurizer)
    dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=data_dir)
  loader = dc.data.CSVLoader(
      tasks=tasks, feature_field="smiles", featurizer=featurizer)
  dataset = loader.create_dataset(dataset_file, shard_size=8192)

  if split is None:
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset, move_mean=move_mean)
    ]

    logger.info("Split is None, about to transform data")
    for transformer in transformers:
      dataset = transformer.transform(dataset)
  # Split and transform the dataset.

    return delaney_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'stratified': deepchem.splits.SingletaskStratifiedSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split dataset with {} splitter.".format(split))
  if splitter is None:
    transformer_dataset: Dataset = dataset
  else:
    logger.info("About to split dataset with {} splitter.".format(
        splitter.__class__.__name__))
    train, valid, test = splitter.train_valid_test_split(dataset)

    transformer_dataset = train
  transformers = [
      deepchem.trans.NormalizationTransformer(
          transform_y=True, dataset=train, move_mean=move_mean)
      dc.trans.NormalizationTransformer(
          transform_y=True, dataset=transformer_dataset, move_mean=move_mean)
  ]

  logger.info("About to transform data.")
  if splitter is None:
    for transformer in transformers:
      dataset = transformer.transform(dataset)
    if reload and isinstance(dataset, DiskDataset):
      dataset.move(save_folder)
      dc.utils.data_utils.save_transformers(save_folder, transformers)
    return tasks, (dataset,), transformers

  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.data_utils.save_dataset_to_disk(save_folder, train, valid,
                                                   test, transformers)
  return delaney_tasks, (train, valid, test), transformers
  if reload and isinstance(train, DiskDataset) and isinstance(
      valid, DiskDataset) and isinstance(test, DiskDataset):
    dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
                                             transformers)
  return tasks, (train, valid, test), transformers
+12 −16
Original line number Diff line number Diff line
@@ -103,8 +103,7 @@ class Splitter(object):
      train_ds_base = DiskDataset.merge(update_train_base_merge)
    return list(zip(train_datasets, cv_datasets))

  def train_valid_test_split(
      self,
  def train_valid_test_split(self,
                             dataset: Dataset,
                             train_dir: Optional[str] = None,
                             valid_dir: Optional[str] = None,
@@ -114,7 +113,7 @@ class Splitter(object):
                             frac_test: float = 0.1,
                             seed: Optional[int] = None,
                             log_every_n: int = 1000,
      **kwargs) -> Tuple[Dataset, Optional[Dataset], Dataset]:
                             **kwargs) -> Tuple[Dataset, Dataset, Dataset]:
    """ Splits self into train/validation/test sets.

    Returns Dataset objects for train, valid, test.
@@ -169,10 +168,7 @@ class Splitter(object):
    if test_dir is None:
      test_dir = tempfile.mkdtemp()
    train_dataset = dataset.select(train_inds, train_dir)
    if frac_valid != 0:
      valid_dataset: Optional[Dataset] = dataset.select(valid_inds, valid_dir)
    else:
      valid_dataset = None
    valid_dataset = dataset.select(valid_inds, valid_dir)
    test_dataset = dataset.select(test_inds, test_dir)
    if isinstance(train_dataset, DiskDataset):
      train_dataset.memory_cache_size = 40 * (1 << 20)  # 40 MB
+13 −3
Original line number Diff line number Diff line
@@ -520,8 +520,7 @@ def load_dataset_from_disk(save_dir: str) -> Tuple[bool, Optional[Tuple[
  test = dc.data.DiskDataset(test_dir)
  train.memory_cache_size = 40 * (1 << 20)  # 40 MB
  all_dataset = (train, valid, test)
  with open(os.path.join(save_dir, "transformers.pkl"), 'rb') as f:
    transformers = pickle.load(f)
  transformers = load_transformers(save_dir)
  return loaded, all_dataset, transformers


@@ -566,6 +565,17 @@ def save_dataset_to_disk(
  train.move(train_dir)
  valid.move(valid_dir)
  test.move(test_dir)
  save_transformers(save_dir, transformers)


def load_transformers(save_dir: str) -> List["dc.trans.Transformer"]:
  """Load the transformers for a MoleculeNet dataset from disk.

  Parameters
  ----------
  save_dir: str
    directory containing a ``transformers.pkl`` file

  Returns
  -------
  The object unpickled from ``<save_dir>/transformers.pkl``.

  Notes
  -----
  The file is read with ``pickle``, which can execute arbitrary code while
  loading; only use this on directories whose contents you trust.
  """
  pickle_path = os.path.join(save_dir, "transformers.pkl")
  with open(pickle_path, 'rb') as handle:
    loaded = pickle.load(handle)
  return loaded


def save_transformers(save_dir: str,
                      transformers: List["dc.trans.Transformer"]) -> None:
  """Save the transformers for a MoleculeNet dataset to disk.

  Parameters
  ----------
  save_dir: str
    directory in which ``transformers.pkl`` is written; the file is opened
    for writing directly, so the directory is expected to exist already
  transformers: list
    the transformers to pickle
  """
  pickle_path = os.path.join(save_dir, "transformers.pkl")
  with open(pickle_path, 'wb') as handle:
    pickle.dump(transformers, handle)