Commit d6606e8d authored by Nathan Frey's avatar Nathan Frey
Browse files

handle complex feat failures

parent 542fafd0
Loading
Loading
Loading
Loading
+17 −6
Original line number Diff line number Diff line
@@ -174,16 +174,27 @@ class ComplexFeaturizer(Featurizer):

    if not isinstance(complexes, Iterable):
      complexes = [cast(Tuple[str, str], complexes)]
    features = []
    for i, point in enumerate(complexes):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
    features, failures = [], []
    for idx, point in enumerate(complexes):
      if idx % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % idx)
      try:
        features.append(self._featurize(point))
      except:
        logger.warning(
            "Failed to featurize datapoint %i. Appending empty array." % i)
        features.append(np.array([]))
            "Failed to featurize datapoint %i. Appending empty array." % idx)
        features.append(np.zeros(1))
        failures.append(idx)

    # Find a successful featurization
    i = np.argmax([f.shape[0] for f in features])
    dtype = features[i].dtype
    shape = features[i].shape
    dummy_array = np.zeros(shape, dtype=dtype)

    # Replace failed featurizations with appropriate array
    for idx in failures:
      features[idx] = dummy_array

    return np.asarray(features)

+1 −3
Original line number Diff line number Diff line
@@ -449,7 +449,5 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        else:
          features_dict[system_id] = np.concatenate(feature_arrays, axis=-1)

    features = np.array(list(features_dict.values()))
    if self.nb_rotations == 0:  # squeeze out axis with dimension 1
      features = np.squeeze(features, axis=0)
    features = np.concatenate(list(features_dict.values()))
    return features
+69 −6
Original line number Diff line number Diff line
@@ -36,7 +36,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
  def test_example_featurizer(self):
    # check if use-case from examples works
    featurizer = RdkitGridFeaturizer(
        voxel_width=16.0,
        voxel_width=1.0,
        box_width=75.0,
        feature_types=['ecfp', 'splif', 'hbond', 'salt_bridge'],
        ecfp_power=9,
        splif_power=9,
@@ -85,6 +86,7 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    # test flat features
    featurizer = RdkitGridFeaturizer(
        voxel_width=1.0,
        box_width=75.0,
        feature_types=['flat_combined'],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
@@ -100,7 +102,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):

    # check if aromatic features are ignored if sanitize=False
    featurizer = RdkitGridFeaturizer(
        voxel_width=16.0,
        voxel_width=1.0,
        box_width=75.0,
        feature_types=['all_combined'],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
@@ -112,8 +115,7 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    feature_tensor = featurizer.featurize([(self.ligand_file,
                                            self.protein_file)])
    self.assertIsInstance(feature_tensor, np.ndarray)
    total_len = voxel_total_len + flat_total_len - 3 - 2**ecfp_power
    self.assertEqual(feature_tensor.shape, (1, total_len))
    self.assertEqual(feature_tensor.shape, (1, 56109538))

  def test_custom_cutoffs(self):
    custom_cutoffs = {
@@ -134,11 +136,72 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
  def test_rotations(self):
    featurizer = RdkitGridFeaturizer(
        nb_rotations=3,
        box_width=16.,
        box_width=75.,
        voxel_width=1.,
        feature_types=['voxel_combined'],
        flatten=False,
        sanitize=True)
    feature_tensors = featurizer.featurize([(self.ligand_file,
                                             self.protein_file)])
    self.assertEqual(feature_tensors.shape, (1, 4, 16, 16, 16, 40))
    self.assertEqual(feature_tensors.shape, (1, 300, 75, 75, 40))

    featurizer = RdkitGridFeaturizer(
        nb_rotations=3,
        box_width=75.,
        voxel_width=1.,
        feature_types=['flat_combined'],
        flatten=True,
        sanitize=True)
    feature_tensors = featurizer.featurize([(self.ligand_file,
                                             self.protein_file)])
    self.assertEqual(feature_tensors.shape, (1, 204))

  def test_failures(self):
    # test flattened voxel features
    featurizer = RdkitGridFeaturizer(
        nb_rotations=0,
        box_width=75.,
        voxel_width=1.,
        feature_types=['voxel_combined'],
        flatten=True,
        sanitize=True)

    features = featurizer.featurize([(self.ligand_file, self.protein_file),
                                     ('nan', 'nan')])
    self.assertEqual(features.shape, (2, 16875000))

    # test voxel features
    featurizer = RdkitGridFeaturizer(
        nb_rotations=0,
        box_width=75.,
        voxel_width=1.,
        feature_types=['voxel_combined'],
        flatten=False,
        sanitize=True)
    features = featurizer.featurize([(self.ligand_file, self.protein_file),
                                     ('nan', 'nan')])
    self.assertEqual(features.shape, (2, 75, 75, 75, 40))

    # test flat features
    featurizer = RdkitGridFeaturizer(
        nb_rotations=0,
        box_width=75.,
        voxel_width=1.,
        feature_types=['flat_combined'],
        flatten=True,
        sanitize=True)
    features = featurizer.featurize([(self.ligand_file, self.protein_file),
                                     ('nan', 'nan')])
    self.assertEqual(features.shape, (2, 51))

    # test rotations
    featurizer = RdkitGridFeaturizer(
        nb_rotations=5,
        box_width=75.,
        voxel_width=1.,
        feature_types=['flat_combined'],
        flatten=True,
        sanitize=True)
    features = featurizer.featurize([(self.ligand_file, self.protein_file),
                                     ('nan', 'nan')])
    self.assertEqual(features.shape, (2, 306))