Commit e8eb8b82 authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #203 from rbharath/grabbag

Fixes to NCI example, and better handling of large datasets
parents 61bc5bf6 2b01efa5
Loading
Loading
Loading
Loading
−47 KiB (51.6 MiB)

File changed.

No diff preview for this file type.

−46.7 KiB (51.5 MiB)

File changed.

No diff preview for this file type.

+34 −11
Original line number Diff line number Diff line
@@ -18,10 +18,6 @@ __author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

# TODO(rbharath): The semantics of this class are very difficult to debug.
# Multiple transformations of the data are performed on disk, and computations
# of mean/std are spread across multiple functions for efficiency. Some
# refactoring needs to happen here.
class Dataset(object):
  """
  Wrapper class for dataset transformed into X, y, w numpy ndarrays.
@@ -140,7 +136,6 @@ class Dataset(object):
            out_X_sums, out_X_sum_squares, out_X_n,
            out_y_sums, out_y_sum_squares, out_y_n]


  def save_to_disk(self):
    """Save dataset to disk."""
    save_to_disk(
@@ -270,7 +265,6 @@ class Dataset(object):
                   metadata_rows=metadata_rows,
                   verbosity=self.verbosity)


  def shuffle(self, iterations=1):
    """Shuffles this dataset on disk to have random order."""
    #np.random.seed(9452)
@@ -331,12 +325,41 @@ class Dataset(object):

  def select(self, select_dir, indices):
    """Creates a new dataset from a selection of indices from self.

    The selection is done one shard at a time (via self.itershards()) so the
    full dataset is never materialized in memory at once.

    Parameters
    ----------
    select_dir: str
      Directory the selected dataset is written to. Created if it does not
      exist.
    indices: sequence of int
      Row positions to keep, counted across all shards of self. They are
      sorted before use, so the resulting dataset is in ascending index order.

    Returns
    -------
    Dataset
      Dataset rooted at select_dir. If indices is empty, a Dataset with no
      metadata rows is returned.
    """
    if not os.path.exists(select_dir):
      os.makedirs(select_dir)
    # Use len() rather than truthiness: indices may be a numpy array.
    if not len(indices):
      # Keyword is metadata_rows, matching the constructor call at the end of
      # this method.
      return Dataset(
          data_dir=select_dir, metadata_rows=[], verbosity=self.verbosity)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
    metadata_rows = []
    # count: global row offset of the first row of the current shard.
    # indices_count: how many entries of indices have been consumed so far.
    count, indices_count = 0, 0
    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      log("Selecting from shard %d" % shard_num, self.verbosity)
      shard_len = len(X)
      # indices is sorted, so the entries that fall inside this shard form a
      # contiguous run ending just before the first entry >= the shard's end.
      shard_stop = int(np.searchsorted(indices, count + shard_len))
      # Offset the global indices so they address rows within this shard.
      shard_indices = indices[indices_count:shard_stop] - count
      X_sel = X[shard_indices]
      y_sel = y[shard_indices]
      w_sel = w[shard_indices]
      ids_sel = ids[shard_indices]
      basename = "dataset-%d" % shard_num
      metadata_rows.append(
          Dataset.write_data_to_disk(select_dir, basename, tasks,
                                     X_sel, y_sel, w_sel, ids_sel))
      # Updating counts
      indices_count = shard_stop
      count += shard_len
      # Every requested index has been written; later shards contribute
      # nothing, so stop instead of reading them.
      if indices_count >= len(indices):
        break
    return Dataset(data_dir=select_dir,
                   metadata_rows=metadata_rows,
                   verbosity=self.verbosity)
    
  def to_numpy(self):
    """
+10 −9
Original line number Diff line number Diff line
@@ -15,12 +15,11 @@ from deepchem.utils.save import load_sharded_csv
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import DataFeaturizer
from deepchem.featurizers.fingerprints import CircularFingerprint
from deepchem.transformers import BalancingTransformer
from deepchem.transformers import NormalizationTransformer

def load_nci(base_dir, reload=True):
def load_nci(base_dir, reload=True, force_transform=False):
  """Load NCI datasets. Does not do train/test split"""
  # Set some global variables up top
  #reload = True
  verbosity = "high"
  model = "logistic"
  regen = False
@@ -29,7 +28,7 @@ def load_nci(base_dir, reload=True):
  # The base_dir holds the results of all analysis
  if not reload:
    if os.path.exists(base_dir):
      print("deleting dir in datasets.py")
      print("Deleting dir in nci_datasets.py")
      print(base_dir)
      shutil.rmtree(base_dir)
  if not os.path.exists(base_dir):
@@ -63,23 +62,25 @@ def load_nci(base_dir, reload=True):
                    'OVCAR-3', 'OVCAR-4', 'OVCAR-5', 'OVCAR-8', 'NCI/ADR-RES',
                    'SK-OV-3', '786-0', 'A498', 'ACHN', 'CAKI-1', 'RXF 393',
                    'SN12C', 'TK-10', 'UO-31', 'PC-3', 'DU-145', 'MCF7',
                    'MDA-MB-231/ATCC', 'MDA-MB-468', 'HS 578T', 'MDA-N', 'BT-549'])
                    'MDA-MB-231/ATCC', 'MDA-MB-468', 'HS 578T', 'BT-549',
                    'T-47D'])

  featurizer = DataFeaturizer(tasks=all_nci_tasks,
                              smiles_field="smiles",
                              featurizers=featurizers,
                              verbosity=verbosity)
  if not reload or not os.path.exists(data_dir):
    dataset = featurizer.featurize(dataset_file, data_dir)
    dataset = featurizer.featurize(dataset_paths, data_dir)
    regen = True
  else:
    dataset = Dataset(data_dir, reload=True)

  # Initialize transformers
  transformers = [
      BalancingTransformer(transform_w=True, dataset=dataset)]
  if regen:
  transformers = []
  if regen or force_transform:
    print("About to transform data")
    transformers = [
        NormalizationTransformer(transform_y=True, dataset=dataset)]
    for transformer in transformers:
        transformer.transform(dataset)

+22 −0
Original line number Diff line number Diff line
@@ -48,6 +48,28 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    solubility_dataset = self.load_solubility_data()
    assert len(solubility_dataset) == 10

  def test_select(self):
    """Test that dataset select works.

    Builds a small random dataset, selects a subset of rows by index, and
    checks that X, y, w and ids of the selected dataset match the
    corresponding rows of the source arrays.
    """
    num_datapoints = 10
    num_features = 10
    num_tasks = 1
    X = np.random.rand(num_datapoints, num_features)
    y = np.random.randint(2, size=(num_datapoints, num_tasks))
    w = np.ones((num_datapoints, num_tasks))
    # Distinct ids, so the ids assertion below can detect a wrong selection
    # (identical ids would compare equal for any choice of rows).
    ids = np.array(["id%d" % i for i in range(num_datapoints)])
    dataset = Dataset.from_numpy(self.data_dir, X, y, w, ids)

    select_dir = tempfile.mkdtemp()
    try:
      indices = [0, 4, 5, 8]
      select_dataset = dataset.select(select_dir, indices)
      X_sel, y_sel, w_sel, ids_sel = select_dataset.to_numpy()
      np.testing.assert_array_equal(X[indices], X_sel)
      np.testing.assert_array_equal(y[indices], y_sel)
      np.testing.assert_array_equal(w[indices], w_sel)
      np.testing.assert_array_equal(ids[indices], ids_sel)
    finally:
      # Remove the temporary directory even when an assertion fails.
      shutil.rmtree(select_dir)
    
  
  def test_iterbatches(self):
    """Test that iterating over batches of data works."""
    solubility_dataset = self.load_solubility_data()
Loading