Commit d0aa21b6 authored by Bharath Ramsundar
Browse files

Partial fix for data split issues

parent fdc23975
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -114,7 +114,6 @@ class DataFeaturizer(object):
    self.mol_field = mol_field
    self.user_specified_features = user_specified_features
    self.featurizers = featurizers
    #self.complex_featurizers = complex_featurizers
    self.log_every_n = log_every_n

  def _load_sdf_file(self, input_file):
+34 −39
Original line number Diff line number Diff line
@@ -14,7 +14,6 @@ import numpy as np
from rdkit import Chem
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
#from deepchem.featurizers.featurize import FeaturizedSamples
from deepchem.datasets import Dataset

def generate_scaffold(smiles, include_chirality=False):
@@ -39,9 +38,6 @@ class Splitter(object):
      if given_dir is None:
        continue
        
      # TODO(rbharath): This is uncomfortably tied to the internal
      # implementation of FeaturizedSamples. Disentangle Splitter and
      # FeaturizedSamples in a future refactoring.
      compounds_filename = os.path.join(given_dir, "datasets.joblib")
      if not os.path.exists(compounds_filename):
        return False
@@ -55,7 +51,7 @@ class Splitter(object):
    """
    Splits self into train/validation/test sets.

    Returns FeaturizedDataset objects.
    Returns Dataset objects.
    """
    compute_split = (
        not reload
@@ -70,7 +66,7 @@ class Splitter(object):
    dataset_files = samples.dataset_files

    # Generate train dir
    train_samples = FeaturizedSamples(samples_dir=train_dir, 
    train_samples = Dataset(samples_dir=train_dir, 
                            dataset_files=dataset_files,
                            featurizers=samples.featurizers,
                            verbosity=self.verbosity,
@@ -78,7 +74,7 @@ class Splitter(object):
    if compute_split:
      train_samples._set_compound_df(samples.compounds_df.iloc[train_inds])
    # Generate test dir
    test_samples = FeaturizedSamples(samples_dir=test_dir, 
    test_samples = Dataset(samples_dir=test_dir, 
                           dataset_files=dataset_files,
                           featurizers=samples.featurizers,
                           verbosity=self.verbosity,
@@ -87,7 +83,7 @@ class Splitter(object):
      test_samples._set_compound_df(samples.compounds_df.iloc[test_inds])
    # If requested, generate valid_dir
    if valid_dir is not None:
      valid_samples = FeaturizedSamples(samples_dir=valid_dir, 
      valid_samples = Dataset(samples_dir=valid_dir, 
                              dataset_files=dataset_files,
                              featurizers=samples.featurizers,
                              verbosity=self.verbosity,
@@ -102,7 +98,7 @@ class Splitter(object):
    """
    Splits self into train/test sets.

    Returns FeaturizedDataset objects.
    Returns Dataset objects.
    """
    valid_dir = None
    train_samples, _, test_samples = self.train_valid_test_split(
@@ -122,7 +118,7 @@ class MolecularWeightSplitter(Splitter):
  """
  Class for doing data splits by molecular weight.
  """
  def split(self, samples, seed=None, frac_train=.8, frac_valid=.1,
  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits internal compounds into train/validation/test using the MW calculated
@@ -132,10 +128,9 @@ class MolecularWeightSplitter(Splitter):
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    np.random.seed(seed)

    df = samples.compounds_df
    mws = []
    for _, row in df.iterrows():
        mol = Chem.MolFromSmiles(row['smiles'])
    for smiles in dataset.get_ids():
      mol = Chem.MolFromSmiles(smiles)
      mw = Chem.rdMolDescriptors.CalcExactMolWt(mol)
      mws.append(mw)

@@ -153,16 +148,16 @@ class RandomSplitter(Splitter):
  """
  Class for doing random data splits.
  """
  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits internal compounds randomly into train/validation/test.

    Parameters
    ----------
    dataset: sized object
      Collection to split; only len(dataset) is used here.
    seed: int, optional
      Passed to np.random.seed so the shuffle is reproducible.
    frac_train, frac_valid, frac_test: float
      Fractions of the data assigned to each split; must sum to 1.
    log_every_n: int, optional
      Unused by this splitter; accepted so the signature matches the
      other splitters' split methods.

    Returns
    -------
    Tuple (train_inds, valid_inds, test_inds) of index arrays into dataset.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    np.random.seed(seed)
    num_datapoints = len(dataset)
    # Slice bounds must be integers; the raw products are floats, which
    # NumPy refuses to accept as indices.
    train_cutoff = int(frac_train * num_datapoints)
    valid_cutoff = int((frac_train + frac_valid) * num_datapoints)
    shuffled = np.random.permutation(range(num_datapoints))
    return (shuffled[:train_cutoff], shuffled[train_cutoff:valid_cutoff],
            shuffled[valid_cutoff:])

@@ -170,7 +165,7 @@ class ScaffoldSplitter(Splitter):
  """
  Class for doing data splits based on the scaffold of small molecules.
  """
  def split(self, samples, frac_train=.8, frac_valid=.1, frac_test=.1,
  def split(self, dataset, frac_train=.8, frac_valid=.1, frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by scaffold.
@@ -178,11 +173,11 @@ class ScaffoldSplitter(Splitter):
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    scaffolds = {}
    log("About to generate scaffolds", self.verbosity)
    for ind, row in samples.compounds_df.iterrows():
    data_len = len(dataset)
    for smiles in dataset.get_ids():
      if self.verbosity is not None and ind % log_every_n == 0:
        log("Generating scaffold %d/%d" % (ind, len(samples.compounds_df)),
            self.verbosity)
      scaffold = generate_scaffold(row["smiles"])
        log("Generating scaffold %d/%d" % (ind, data_len), self.verbosity)
      scaffold = generate_scaffold(smiles)
      if scaffold not in scaffolds:
        scaffolds[scaffold] = [ind]
      else:
@@ -190,8 +185,8 @@ class ScaffoldSplitter(Splitter):
    # Sort from largest to smallest scaffold sets
    scaffold_sets = [scaffold_set for (scaffold, scaffold_set) in
                     sorted(scaffolds.items(), key=lambda x: -len(x[1]))]
    train_cutoff = frac_train * len(samples.compounds_df)
    valid_cutoff = (frac_train+frac_valid) * len(samples.compounds_df)
    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train+frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []
    log("About to sort in scaffold sets", self.verbosity)
    for scaffold_set in scaffold_sets:
@@ -208,7 +203,7 @@ class SpecifiedSplitter(Splitter):
  """
  Class that splits data according to user specification.
  """
  def split(self, samples, frac_train=.8, frac_valid=.1, frac_test=.1,
  def split(self, dataset, frac_train=.8, frac_valid=.1, frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by user-specification.