Commit e5916873 authored by Bharath Ramsundar
Browse files

Fixing broken tests

parent 1eb99723
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -15,9 +15,9 @@ from deepchem.utils.save import load_sharded_csv
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import DataFeaturizer
from deepchem.featurizers.fingerprints import CircularFingerprint
from deepchem.transformers import BalancingTransformer
from deepchem.transformers import NormalizationTransformer

def load_nci(base_dir, reload=True):
def load_nci(base_dir, reload=True, force_transform=False):
  """Load NCI datasets. Does not do train/test split"""
  # Set some global variables up top
  verbosity = "high"
@@ -76,10 +76,10 @@ def load_nci(base_dir, reload=True):
    dataset = Dataset(data_dir, reload=True)

  # Initialize transformers
  transformers = [
      BalancingTransformer(transform_w=True, dataset=dataset)]
  if regen:
  if regen or force_transform:
    print("About to transform data")
    transformers = [
        NormalizationTransformer(transform_y=True, dataset=dataset)]
    for transformer in transformers:
        transformer.transform(dataset)

+2 −4
Original line number Diff line number Diff line
@@ -44,10 +44,8 @@ class RDKitDescriptors(Featurizer):

  def __init__(self):
    self.descriptors = []
    self.functions = []
    for descriptor, function in Descriptors.descList:
      self.descriptors.append(descriptor)
      self.functions.append(function)

  def _featurize(self, mol):
    """
@@ -59,6 +57,6 @@ class RDKitDescriptors(Featurizer):
        Molecule.
    """
    rval = []
    for function in self.functions:
    for _, function in Descriptors.descList:
      rval.append(function(mol))
    return rval
+1 −1
Original line number Diff line number Diff line
@@ -168,7 +168,7 @@ class SpecifiedSplitter(Splitter):

  def __init__(self, input_file, split_field, verbosity=None):
    """Provide input information for splits."""
    raw_df = load_data(input_file, shard_size=None).next()
    raw_df = load_data([input_file], shard_size=None).next()
    self.splits = raw_df[split_field].values
    self.verbosity = verbosity

+1 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import pickle
import pandas as pd
import numpy as np
import os
from rdkit import Chem

def log(string, verbosity=None, level="low"):
  """Print string if verbose."""
+3 −2
Original line number Diff line number Diff line
@@ -22,14 +22,15 @@ np.random.seed(123)

# Set some global variables up top

reload = False
reload = True
verbosity = "high"
force_transform = False 

base_data_dir = "/scratch/users/rbharath/nci_data_dir"
base_dir = "/scratch/users/rbharath/nci_analysis_dir"

nci_tasks, dataset, transformers = load_nci(
    base_data_dir, reload=reload)
    base_data_dir, reload=reload, force_transform=force_transform)

if os.path.exists(base_dir):
  shutil.rmtree(base_dir)