Commit fdc23975 authored by Bharath Ramsundar
Browse files

Fixed some of the tests

parent e9760d71
Loading
Loading
Loading
Loading
+26 −62
Original line number Diff line number Diff line
@@ -25,90 +25,54 @@ class TestSplitAPI(unittest.TestCase):
    self.current_dir = os.path.dirname(os.path.abspath(__file__))
    self.test_data_dir = os.path.join(self.current_dir, "../../models/tests")
    self.smiles_field = "smiles"
    self.feature_dir = tempfile.mkdtemp()
    self.samples_dir = tempfile.mkdtemp()
    self.data_dir = tempfile.mkdtemp()
    self.train_dir = tempfile.mkdtemp()
    self.valid_dir = tempfile.mkdtemp()
    self.test_dir = tempfile.mkdtemp()

  def tearDown(self):
    """Delete every scratch directory referenced by this test case.

    The directories are removed one after another in a fixed order. An error
    raised by ``shutil.rmtree`` for one directory propagates immediately, so
    the directories that come after it are left in place.
    """
    scratch_attrs = ("feature_dir", "samples_dir", "data_dir",
                     "train_dir", "valid_dir", "test_dir")
    for attr in scratch_attrs:
      # Look the path up lazily, right before removing it.
      shutil.rmtree(getattr(self, attr))

  def _gen_samples(self, compound_featurizers, complex_featurizers,
                   input_transformer_classes, output_transformer_classes,
                   input_file, tasks,
                   protein_pdb_field=None, ligand_pdb_field=None,
                   user_specified_features=None,
                   split_field=None,
                   shard_size=100):
    # Featurize input
    featurizers = compound_featurizers + complex_featurizers

    input_file = os.path.join(self.test_data_dir, input_file)
  def load_solubility_data(self):
    """Loads solubility data from example.csv"""
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["log-solubility"]
    task_type = "regression"
    input_file = os.path.join(self.test_data_dir, "example.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        protein_pdb_field=protein_pdb_field,
        ligand_pdb_field=ligand_pdb_field,
        compound_featurizers=compound_featurizers,
        complex_featurizers=complex_featurizers,
        user_specified_features=user_specified_features,
        split_field=split_field,
        featurizers=featurizers,
        verbosity="low")

    samples = featurizer.featurize(input_file, self.feature_dir, self.samples_dir,
                                   shard_size=shard_size)
    return samples
    return featurizer.featurize(input_file, self.data_dir)

  def _load_solubility_samples(self):
    """Featurize the solubility example in ``example.csv``.

    Uses a single ``CircularFingerprint(size=1024)`` compound featurizer, no
    complex featurizers and no input/output transformer classes, and hands
    everything to ``self._gen_samples`` for the ``log-solubility`` task.

    Returns:
      Whatever ``self._gen_samples`` returns for this configuration.
    """
    solubility_tasks = ["log-solubility"]
    fingerprints = [CircularFingerprint(size=1024)]
    # Positional order expected by _gen_samples: compound featurizers,
    # complex featurizers, input transformers, output transformers,
    # input file, tasks. The three empty lists mean "none of these".
    return self._gen_samples(
        fingerprints, [], [], [], "example.csv", solubility_tasks)

  def _load_classification_samples(self):
  def load_classification_data(self):
    """Loads classification data from example.csv"""
    compound_featurizers = [CircularFingerprint(size=1024)]
    complex_featurizers = []
    input_transformer_classes = []
    output_transformer_classes = []
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["outcome"]
    task_type = "classification"
    task_types = {task: task_type for task in tasks}
    input_file = "example_classification.csv"
    return self._gen_samples(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks)
    input_file = os.path.join(self.test_data_dir, "example_classification.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=featurizers,
        verbosity="low")
    return featurizer.featurize(input_file, self.data_dir)

  def _load_multitask_samples(self):
  def load_multitask_data(self):
    """Load example multitask data."""
    compound_featurizers = [CircularFingerprint(size=1024)]
    complex_featurizers = []
    output_transformer_classes = []
    input_transformer_classes = []
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["task0", "task1", "task2", "task3", "task4", "task5", "task6",
             "task7", "task8", "task9", "task10", "task11", "task12",
             "task13", "task14", "task15", "task16"]
    task_types = {task: "classification" for task in tasks}
    input_file = "multitask_example.csv"
    return self._gen_samples(
        compound_featurizers, complex_featurizers,
        input_transformer_classes, output_transformer_classes,
        input_file, tasks)
    input_file = os.path.join(self.test_data_dir, "multitask_example.csv")
    featurizer = DataFeaturizer(
        tasks=tasks,
        smiles_field=self.smiles_field,
        featurizers=featurizers,
        verbosity="low")
    return featurizer.featurize(input_file, self.data_dir)