Unverified Commit e57fa410 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2325 from peastman/fingerprint

Created FingerprintSplitter
parents b804376c bd9012ec
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -77,6 +77,7 @@ splitters = {
    'random': dc.splits.RandomSplitter(),
    'scaffold': dc.splits.ScaffoldSplitter(),
    'butina': dc.splits.ButinaSplitter(),
    'fingerprint': dc.splits.FingerprintSplitter(),
    'task': dc.splits.TaskSplitter(),
    'stratified': dc.splits.RandomStratifiedSplitter()
}
+148 −127
Original line number Diff line number Diff line
@@ -1061,14 +1061,14 @@ class ButinaSplitter(Splitter):
    super(ButinaSplitter, self).__init__()
    self.cutoff = cutoff

  def split(
      self,
  def split(self,
            dataset: Dataset,
            frac_train: float = 0.8,
            frac_valid: float = 0.1,
            frac_test: float = 0.1,
            seed: Optional[int] = None,
      log_every_n: Optional[int] = None) -> Tuple[List[int], List[int], List]:
            log_every_n: Optional[int] = None
           ) -> Tuple[List[int], List[int], List[int]]:
    """
    Splits internal compounds into train and validation based on the butina
    clustering algorithm. This splitting algorithm has an O(N^2) run time, where N
@@ -1084,11 +1084,11 @@ class ButinaSplitter(Splitter):
    dataset: Dataset
      Dataset to be split.
    frac_train: float, optional (default 0.8)
      The fraction of data to be used for the training split (not currently used).
      The fraction of data to be used for the training split.
    frac_valid: float, optional (default 0.1)
      The fraction of data to be used for the validation split (not currently used).
      The fraction of data to be used for the validation split.
    frac_test: float, optional (default 0.1)
      The fraction of data to be used for the test split (not currently used).
      The fraction of data to be used for the test split.
    seed: int, optional (default None)
      Random seed to use.
    log_every_n: int, optional (default None)
@@ -1098,7 +1098,6 @@ class ButinaSplitter(Splitter):
    -------
    Tuple[List[int], List[int], List[int]]
      A tuple of train indices, valid indices, and test indices.
      Each indices is a list of integers and test indices is always an empty list.
    """
    try:
      from rdkit import Chem, DataStructs
@@ -1181,6 +1180,143 @@ def _generate_scaffold(smiles: str, include_chirality: bool = False) -> str:
  return scaffold


class FingerprintSplitter(Splitter):
  """Splitter based on Tanimoto similarity between ECFP4 fingerprints.

  The data is partitioned so that the molecules placed in each dataset are as
  dissimilar as possible from the molecules in the other datasets, which makes
  this a very stringent test of models: predicting the validation and test
  sets may require extrapolating far outside the training data.

  The running time scales as O(n^2) in the number of samples, so splitting
  large datasets can take a long time.

  Note
  ----
  This class requires RDKit to be installed.
  """

  def __init__(self):
    """Create a FingerprintSplitter."""
    super(FingerprintSplitter, self).__init__()

  def split(self,
            dataset: Dataset,
            frac_train: float = 0.8,
            frac_valid: float = 0.1,
            frac_test: float = 0.1,
            seed: Optional[int] = None,
            log_every_n: Optional[int] = None
           ) -> Tuple[List[int], List[int], List[int]]:
    """
    Splits compounds into training, validation, and test sets based on the
    Tanimoto similarity of their ECFP4 fingerprints. This splitting algorithm
    has an O(N^2) run time, where N is the number of elements in the dataset.

    Parameters
    ----------
    dataset: Dataset
      Dataset to be split.
    frac_train: float, optional (default 0.8)
      The fraction of data to be used for the training split.
    frac_valid: float, optional (default 0.1)
      The fraction of data to be used for the validation split.
    frac_test: float, optional (default 0.1)
      The fraction of data to be used for the test split.
    seed: int, optional (default None)
      Random seed to use (ignored since this algorithm is deterministic).
    log_every_n: int, optional (default None)
      Log every n examples (not currently used).

    Returns
    -------
    Tuple[List[int], List[int], List[int]]
      A tuple of train indices, valid indices, and test indices.
    """
    try:
      from rdkit import Chem
      from rdkit.Chem import AllChem
    except ModuleNotFoundError:
      raise ImportError("This function requires RDKit to be installed.")

    # Fingerprint every molecule; dataset ids are assumed to be SMILES strings.
    fps = [
        AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smiles), 2,
                                              1024) for smiles in dataset.ids
    ]

    n_samples = len(dataset)
    train_size = int(frac_train * n_samples)
    valid_size = int(frac_valid * n_samples)
    test_size = n_samples - train_size - valid_size

    # First pass: separate the training set from everything else.
    train_inds, rest_inds = _split_fingerprints(fps, train_size,
                                                valid_size + test_size)

    # Second pass: carve the remainder into test and validation sets.  If one
    # of them is empty, the whole remainder goes to the other.
    if valid_size == 0:
      return train_inds, [], rest_inds
    if test_size == 0:
      return train_inds, rest_inds, []
    rest_fps = [fps[i] for i in rest_inds]
    sub_test, sub_valid = _split_fingerprints(rest_fps, test_size, valid_size)
    # Map the sub-split indices back to indices into the full dataset.
    test_inds = [rest_inds[i] for i in sub_test]
    valid_inds = [rest_inds[i] for i in sub_valid]
    return train_inds, valid_inds, test_inds


def _split_fingerprints(fps: List, size1: int,
                        size2: int) -> Tuple[List[int], List[int]]:
  """This is called by FingerprintSplitter to divide a list of fingerprints into
  two groups, keeping the groups as mutually dissimilar as possible.

  Parameters
  ----------
  fps: List
    RDKit bit-vector fingerprints, one per molecule.
  size1: int
    Number of fingerprints to place in the first group.
  size2: int
    Number of fingerprints to place in the second group.

  Returns
  -------
  Tuple[List[int], List[int]]
    Indices (into fps) of the members of the first and second groups.
  """
  assert len(fps) == size1 + size2
  from rdkit import DataStructs

  # Degenerate splits: everything belongs to the non-empty group.  This also
  # avoids a division by zero in the proportional-fill test below (the
  # original code crashed when size1 == 0, e.g. frac_train=0).
  if size1 == 0:
    return [], list(range(len(fps)))
  if size2 == 0:
    return list(range(len(fps))), []

  # Begin by assigning the first molecule to the first group.

  indices_in_group: Tuple[List[int], List[int]] = ([0], [])
  remaining_fp = fps[1:]
  remaining_indices = list(range(1, len(fps)))
  # For every unassigned molecule, the maximum Tanimoto similarity to any
  # member of group 0 / group 1 respectively.
  max_similarity_to_group = [
      DataStructs.BulkTanimotoSimilarity(fps[0], remaining_fp),
      [0] * len(remaining_fp)
  ]
  while len(remaining_fp) > 0:
    # Decide which group to assign a molecule to: whichever is proportionally
    # less full.  (Only the counts are needed; we do not keep copies of the
    # assigned fingerprints themselves.)

    group = 0 if (len(indices_in_group[0]) / size1 <=
                  len(indices_in_group[1]) / size2) else 1

    # Identify the unassigned molecule that is least similar to everything in
    # the other group.

    i = np.argmin(max_similarity_to_group[1 - group])

    # Add it to the group.

    fp = remaining_fp[i]
    indices_in_group[group].append(remaining_indices[i])

    # Update the data on unassigned molecules.

    similarity = DataStructs.BulkTanimotoSimilarity(fp, remaining_fp)
    max_similarity_to_group[group] = np.delete(
        np.maximum(similarity, max_similarity_to_group[group]), i)
    max_similarity_to_group[1 - group] = np.delete(
        max_similarity_to_group[1 - group], i)
    del remaining_fp[i]
    del remaining_indices[i]
  return indices_in_group


class ScaffoldSplitter(Splitter):
  """Class for doing data splits based on the scaffold of small molecules.

@@ -1281,121 +1417,6 @@ class ScaffoldSplitter(Splitter):
    return scaffold_sets


class FingerprintSplitter(Splitter):
  """Class for doing data splits based on the fingerprints of small
  molecules O(N**2) algorithm.

  Note
  ----
  This class requires RDKit to be installed.
  """

  def split(self,
            dataset: Dataset,
            frac_train: float = 0.8,
            frac_valid: float = 0.1,
            frac_test: float = 0.1,
            seed: Optional[int] = None,
            log_every_n: Optional[int] = None
           ) -> Tuple[List[int], List[int], List[int]]:
    """
    Splits internal compounds into train/validation/test by fingerprint.

    Parameters
    ----------
    dataset: Dataset
      Dataset to be split.
    frac_train: float, optional (default 0.8)
      The fraction of data to be used for the training split.
    frac_valid: float, optional (default 0.1)
      The fraction of data to be used for the validation split.
    frac_test: float, optional (default 0.1)
      The fraction of data to be used for the test split.
    seed: int, optional (default None)
      Random seed to use.
    log_every_n: int, optional (default None)
      Log every n examples (not currently used).

    Returns
    -------
    Tuple[List[int], List[int], List[int]]
      A tuple of train indices, valid indices, and test indices.
      Each indices is a list of integers.
    """
    try:
      from rdkit import Chem, DataStructs
      from rdkit.Chem.Fingerprints import FingerprintMols
    except ModuleNotFoundError:
      raise ImportError("This function requires RDKit to be installed.")

    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    data_len = len(dataset)
    train_inds: List[int] = []
    valid_inds: List[int] = []
    fingerprints = []
    for smiles in dataset.ids:
      mol = Chem.MolFromSmiles(smiles, sanitize=False)
      fingerprints.append(FingerprintMols.FingerprintMol(mol))

    # Pairwise Tanimoto distances.  The matrix is symmetric, so each
    # off-diagonal pair is computed once and mirrored.
    distances = np.ones(shape=(data_len, data_len))
    for i in range(data_len):
      distances[i][i] = 1 - DataStructs.FingerprintSimilarity(
          fingerprints[i], fingerprints[i])
      for j in range(i + 1, data_len):
        dist = 1 - DataStructs.FingerprintSimilarity(fingerprints[i],
                                                     fingerprints[j])
        distances[i][j] = dist
        distances[j][i] = dist

    train_cutoff = int(frac_train * len(dataset))
    valid_cutoff = int(frac_valid * len(dataset))

    # Indices already assigned to train or valid; a set so the membership
    # tests inside the selection loops are O(1) instead of O(n).
    assigned = set()

    # Pick the mol closest to everything as the first element of training.
    closest_ligand = np.argmin(np.sum(distances, axis=1))
    train_inds.append(closest_ligand)
    assigned.add(closest_ligand)
    cur_distances = [float('inf')] * data_len
    self.update_distances(closest_ligand, cur_distances, distances, assigned)
    for i in range(1, train_cutoff):
      closest_ligand = np.argmin(cur_distances)
      train_inds.append(closest_ligand)
      assigned.add(closest_ligand)
      self.update_distances(closest_ligand, cur_distances, distances, assigned)

    # Pick the closest remaining mol as the first element of validation.
    index, best_dist = 0, float('inf')
    for i in range(data_len):
      if i in assigned:
        continue
      dist = np.sum(distances[i])
      if dist < best_dist:
        index, best_dist = i, dist
    valid_inds.append(index)
    assigned.add(index)

    cur_distances = [float('inf')] * data_len
    self.update_distances(index, cur_distances, distances, assigned)
    for i in range(1, valid_cutoff):
      closest_ligand = np.argmin(cur_distances)
      valid_inds.append(closest_ligand)
      assigned.add(closest_ligand)
      self.update_distances(closest_ligand, cur_distances, distances, assigned)

    # Test is everything else.
    test_inds = [i for i in range(data_len) if i not in assigned]
    return train_inds, valid_inds, test_inds

  def update_distances(self, last_selected, cur_distances, distance_matrix,
                       dont_update):
    """Update cur_distances in place after last_selected joined a group.

    For every index not in dont_update, cur_distances[i] becomes the minimum
    of its previous value and the distance to last_selected; indices in
    dont_update are set to infinity so np.argmin never re-selects them.
    """
    for i in range(len(cur_distances)):
      if i in dont_update:
        cur_distances[i] = float('inf')
        continue
      new_dist = distance_matrix[i][last_selected]
      if new_dist < cur_distances[i]:
        cur_distances[i] = new_dist


#################################################################
# Not well supported splitters
#################################################################
+13 −0
Original line number Diff line number Diff line
@@ -581,3 +581,16 @@ class TestSplitter(unittest.TestCase):
    assert not np.array_equal(train1.X, train2.X)
    assert not np.array_equal(valid1.X, valid2.X)
    assert not np.array_equal(test1.X, test2.X)

  def test_fingerprint_split(self):
    """
    Test FingerprintSplitter.
    """
    # 10-sample multitask dataset: an 80/10/10 split must yield 8/1/1.
    dataset = load_multitask_data()
    fp_splitter = dc.splits.FingerprintSplitter()
    train, valid, test = fp_splitter.train_valid_test_split(
        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train) == 8
    assert len(valid) == 1
    assert len(test) == 1