Commit aa6e3b12 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Yet more refactoring

parent d0aa21b6
Loading
Loading
Loading
Loading
+29 −59
Original line number Diff line number Diff line
@@ -20,78 +20,48 @@ from deepchem.featurizers.fingerprints import CircularFingerprint
from deepchem.transformers import NormalizationTransformer
from deepchem.splits.tests import TestSplitAPI

class TestDatasetAPI(TestSplitAPI):
class TestDatasetAPI(TestAPI):
  """
  Shared API for testing with dataset objects. 
  """
  # TODO(rbharath): There should be a more natural way to create a dataset
  # object, perhaps just starting from (Xs, ys, ws)
  def _create_dataset(self, compound_featurizers, complex_featurizers,
                      input_transformer_classes, output_transformer_classes,
                      input_file, tasks,
                      protein_pdb_field=None, ligand_pdb_field=None,
                      user_specified_features=None,
                      split_field=None,
                      shard_size=100):
    featurizers = compound_featurizers + complex_featurizers
    samples = self._gen_samples(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks,
        protein_pdb_field=protein_pdb_field,
        ligand_pdb_field=ligand_pdb_field,
        user_specified_features=user_specified_features,
        split_field=split_field,
        shard_size=shard_size)
    use_user_specified_features = (user_specified_features is not None)
    dataset = Dataset(data_dir=self.data_dir, samples=samples, 
                      featurizers=featurizers, tasks=tasks,
                      use_user_specified_features=use_user_specified_features)
    return dataset

  def _load_solubility_data(self):
  def load_solubility_data(self):
    """Loads solubility data from example.csv"""
    compound_featurizers = [CircularFingerprint(size=1024)]
    complex_featurizers = []
    input_transformer_classes = []
    output_transformer_classes = []
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["log-solubility"]
    task_type = "regression"
    task_types = {task: task_type for task in tasks}
    input_file = "example.csv"
    return self._create_dataset(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks)
    input_file = os.path.join(self.test_data_dir, "example.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=featurizers,
        verbosity="low")

  def _load_classification_data(self):
    return featurizer.featurize(input_file, self.data_dir)

  def load_classification_data(self):
    """Loads classification data from example_classification.csv"""
    compound_featurizers = [CircularFingerprint(size=1024)]
    complex_featurizers = []
    input_transformer_classes = []
    output_transformer_classes = []
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["outcome"]
    task_type = "classification"
    task_types = {task: task_type for task in tasks}
    input_file = "example_classification.csv"
    return self._create_dataset(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks)
    input_file = os.path.join(self.test_data_dir, "example_classification.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=featurizers,
        verbosity="low")
    return featurizer.featurize(input_file, self.data_dir)

  def _load_multitask_data(self):
  def load_multitask_data(self):
    """Load example multitask data."""
    compound_featurizers = [CircularFingerprint(size=1024)]
    complex_featurizers = []
    output_transformer_classes = []
    input_transformer_classes = []
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["task0", "task1", "task2", "task3", "task4", "task5", "task6",
             "task7", "task8", "task9", "task10", "task11", "task12",
             "task13", "task14", "task15", "task16"]
    task_types = {task: "classification" for task in tasks}
    input_file = "multitask_example.csv"
    return self._create_dataset(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks)
    input_file = os.path.join(self.test_data_dir, "multitask_example.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=featurizers,
        verbosity="low")
    return featurizer.featurize(input_file, self.data_dir)
+9 −10
Original line number Diff line number Diff line
@@ -26,32 +26,31 @@ class TestBasicDatasetAPI(TestDatasetAPI):
  """
  def test_get_task_names(self):
    """Test that get_task_names returns correct task_names"""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    assert solubility_dataset.get_task_names() == ["log-solubility"]

    multitask_dataset = self._load_multitask_data()
    multitask_dataset = self.load_multitask_data()
    assert sorted(multitask_dataset.get_task_names()) == sorted(["task0",
        "task1", "task2", "task3", "task4", "task5", "task6", "task7", "task8",
        "task9", "task10", "task11", "task12", "task13", "task14", "task15",
        "task16"])


  def test_get_data_shape(self):
    """Test that get_data_shape returns correct data shape"""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    assert solubility_dataset.get_data_shape() == (1024,) 
    
    multitask_dataset = self._load_multitask_data()
    multitask_dataset = self.load_multitask_data()
    assert multitask_dataset.get_data_shape() == (1024,)

  def test_len(self):
    """Test that len(dataset) works."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    assert len(solubility_dataset) == 10
  
  def test_iterbatches(self):
    """Test that iterating over batches of data works."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    batch_size = 2
    data_shape = solubility_dataset.get_data_shape()
    tasks = solubility_dataset.get_task_names()
@@ -63,7 +62,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):

  def test_to_numpy(self):
    """Test that transformation to numpy arrays is sensible."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    data_shape = solubility_dataset.get_data_shape()
    tasks = solubility_dataset.get_task_names()
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -77,7 +76,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):

  def test_consistent_ordering(self):
    """Test that ordering of labels is consistent over time."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()

    ids1 = solubility_dataset.get_ids()
    ids2 = solubility_dataset.get_ids()
@@ -86,7 +85,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):

  def test_get_statistics(self):
    """Test statistics computation of this dataset."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    X, y, _, _ = solubility_dataset.to_numpy()
    X_means, y_means = np.mean(X, axis=0), np.mean(y, axis=0)
    X_stds, y_stds = np.std(X, axis=0), np.std(y, axis=0)
+0 −5
Original line number Diff line number Diff line
@@ -21,11 +21,6 @@ from deepchem.utils.save import load_pandas_from_disk
from deepchem.featurizers import Featurizer, ComplexFeaturizer
from deepchem.datasets import Dataset

def _check_validity(compounds_df):
  """Ensure that columns of compound_df contain required elements."""
  if not set(FeaturizedSamples.colnames).issubset(compounds_df.keys()):
    raise ValueError("Compound dataframe does not contain required columns")

def _process_field(val):
  """Parse data in a field."""
  if (isinstance(val, numbers.Number) or isinstance(val, np.ndarray)):
+3 −20
Original line number Diff line number Diff line
@@ -32,19 +32,7 @@ class Splitter(object):
    """Creates splitter object."""
    self.verbosity = verbosity

  def _check_populated(self, sample_dirs):
    """Check that the provided sample directories are valid."""
    for given_dir in sample_dirs:
      if given_dir is None:
        continue
        
      compounds_filename = os.path.join(given_dir, "datasets.joblib")
      if not os.path.exists(compounds_filename):
        return False
    return True


  def train_valid_test_split(self, samples, train_dir,
  def train_valid_test_split(self, dataset, train_dir,
                             valid_dir, test_dir, frac_train=.8,
                             frac_valid=.1, frac_test=.1, seed=None,
                             log_every_n=1000, reload=False):
@@ -53,17 +41,12 @@ class Splitter(object):

    Returns Dataset objects.
    """
    compute_split = (
        not reload
        or not self._check_populated([train_dir, test_dir, valid_dir]))
    if compute_split:
    if not reload:
      log("Computing train/valid/test indices", self.verbosity)
      train_inds, valid_inds, test_inds = self.split(
          samples,
          dataset,
          frac_train=frac_train, frac_test=frac_test,
          frac_valid=frac_valid, log_every_n=log_every_n)
    train_samples, valid_samples, test_samples = None, None, None
    dataset_files = samples.dataset_files

    # Generate train dir
    train_samples = Dataset(samples_dir=train_dir, 
+0 −78
Original line number Diff line number Diff line
"""
General API for testing splitter objects
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

import os
import shutil
import tempfile
import unittest
from deepchem.featurizers.featurize import DataFeaturizer
from deepchem.featurizers.fingerprints import CircularFingerprint

class TestSplitAPI(unittest.TestCase):
  """
  Test top-level API for Splitter objects.

  This base class defines no test methods itself. It creates scratch
  directories for each test and offers loaders that featurize the example
  CSV files found under ``self.test_data_dir``.
  """

  # Attributes holding the temporary directories created in setUp and
  # removed again in tearDown (kept in creation order).
  _TEMP_DIR_ATTRS = ("data_dir", "train_dir", "valid_dir", "test_dir")

  def setUp(self):
    """Record fixture locations and create fresh scratch directories."""
    self.current_dir = os.path.dirname(os.path.abspath(__file__))
    self.test_data_dir = os.path.join(self.current_dir, "../../models/tests")
    self.smiles_field = "smiles"
    for attr in self._TEMP_DIR_ATTRS:
      setattr(self, attr, tempfile.mkdtemp())

  def tearDown(self):
    """Remove every scratch directory created by setUp.

    Cleanup is best-effort: a directory that is already gone (or cannot be
    removed) must not prevent the remaining directories from being deleted,
    otherwise they would leak on disk.
    """
    for attr in self._TEMP_DIR_ATTRS:
      path = getattr(self, attr, None)
      # setUp may have failed part-way, leaving some attributes unset.
      if path is not None:
        shutil.rmtree(path, ignore_errors=True)

  def _featurize_csv(self, csv_name, tasks):
    """Featurize one CSV from the test data directory.

    Parameters
    ----------
    csv_name: str
      File name relative to ``self.test_data_dir``.
    tasks: list
      Task (label column) names handed to ``DataFeaturizer``.

    Returns
    -------
    The result of ``DataFeaturizer.featurize`` written to ``self.data_dir``
    (presumably a dataset object -- confirm against DataFeaturizer).
    """
    input_file = os.path.join(self.test_data_dir, csv_name)
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=[CircularFingerprint(size=1024)],
        verbosity="low")
    return featurizer.featurize(input_file, self.data_dir)

  def load_solubility_data(self):
    """Loads solubility data from example.csv"""
    return self._featurize_csv("example.csv", ["log-solubility"])

  def load_classification_data(self):
    """Loads classification data from example_classification.csv"""
    return self._featurize_csv("example_classification.csv", ["outcome"])

  def load_multitask_data(self):
    """Load example multitask data from multitask_example.csv."""
    tasks = ["task0", "task1", "task2", "task3", "task4", "task5", "task6",
             "task7", "task8", "task9", "task10", "task11", "task12",
             "task13", "task14", "task15", "task16"]
    return self._featurize_csv("multitask_example.csv", tasks)