Commit 00d80849 authored by marta-sd's avatar marta-sd
Browse files

added sanitize flag

parent 22798d00
Loading
Loading
Loading
Loading
+36 −16
Original line number Diff line number Diff line
@@ -910,6 +910,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
               voxel_width=1.0,
               flatten=False,
               verbose=True,
               sanitize=False,
               **kwargs):
    """Parameters:
    -----------
@@ -967,6 +968,10 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        'box_x', 'box_y', 'box_z', 'save_intermediates', 'voxelize_features',
        'parallel', 'voxel_feature_types'
    ]

    # list of features that require sanitized molecules
    require_sanitized = ['pi_stack', 'cation_pi']

    for arg in deprecated_args:
      if arg in kwargs and verbose:
        warn('%s argument was removed and it is ignored,'
@@ -974,6 +979,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
             DeprecationWarning)

    self.verbose = verbose
    self.sanitize = sanitize
    self.flatten = flatten

    self.ecfp_degree = ecfp_degree
@@ -1222,8 +1228,20 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    # each entry is a tuple (is_flat, feature_name)
    self.feature_types = []

    # list of features that cannot be calculated with specified parameters
    # this list is used to define <flat/voxel/all>_combined subset
    ignored_features = []
    if self.sanitize is False:
      ignored_features += require_sanitized

    # parse provided feature types
    for feature_type in feature_types:
      if self.sanitize is False and feature_type in require_sanitized:
        if self.verbose:
          warn('sanitize is set to False, %s feature will be ignored' %
               feature_type)
        continue

      if feature_type in self.FLAT_FEATURES:
        self.feature_types.append((True, feature_type))
        if self.flatten is False:
@@ -1235,28 +1253,28 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        self.feature_types.append((False, feature_type))

      elif feature_type == 'flat_combined':
        self.feature_types += list(
            zip([True] * len(self.FLAT_FEATURES),
                sorted(self.FLAT_FEATURES.keys())))
        self.feature_types += [(True, ftype)
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                               if ftype not in ignored_features]
        if self.flatten is False:
          if self.verbose:
            warn('flat features are used, output will be flattened')
            warn('Flat features are used, output will be flattened')
          self.flatten = True

      elif feature_type == 'voxel_combined':
        self.feature_types += list(
            zip([False] * len(self.VOXEL_FEATURES),
                sorted(self.VOXEL_FEATURES.keys())))
        self.feature_types += [(False, ftype)
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                               if ftype not in ignored_features]
      elif feature_type == 'all_combined':
        self.feature_types += list(
            zip([True] * len(self.FLAT_FEATURES),
                sorted(self.FLAT_FEATURES.keys())))
        self.feature_types += list(
            zip([False] * len(self.VOXEL_FEATURES),
                sorted(self.VOXEL_FEATURES.keys())))
        self.feature_types += [(True, ftype)
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                               if ftype not in ignored_features]
        self.feature_types += [(False, ftype)
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                               if ftype not in ignored_features]
        if self.flatten is False:
          if self.verbose:
            warn('flat feature are used, output will be flattened')
            warn('Flat feature are used, output will be flattened')
          self.flatten = True
      elif self.verbose:
        warn('Ignoring unknown feature %s' % feature_type)
@@ -1333,7 +1351,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING

    if not self.ligand_only:
      protein_xyz, protein_rdk = load_molecule(protein_pdb, calc_charges=True)
      protein_xyz, protein_rdk = load_molecule(
          protein_pdb, calc_charges=True, sanitize=self.sanitize)
    ############################################################## TIMING
    time2 = time.time()
    log("TIMING: Loading protein coordinates took %0.3f s" % (time2 - time1),
@@ -1342,7 +1361,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING
    time1 = time.time()
    ############################################################## TIMING
    ligand_xyz, ligand_rdk = load_molecule(ligand_file, calc_charges=True)
    ligand_xyz, ligand_rdk = load_molecule(
        ligand_file, calc_charges=True, sanitize=self.sanitize)
    ############################################################## TIMING
    time2 = time.time()
    log("TIMING: Loading ligand coordinates took %0.3f s" % (time2 - time1),
+10 −10
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import unittest
import numpy as np
np.random.seed(123)

from rdkit.Chem import MolFromMolFile, MolFromSmiles, MolFromPDBFile, SDMolSupplier
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.AllChem import Mol, ComputeGasteigerCharges

from deepchem.feat import rdkit_grid_featurizer as rgf
@@ -205,15 +205,13 @@ class TestPiInteractions(unittest.TestCase):
    self.cycle4.Compute2DCoords()

    # load and sanitize two real molecules
    self.prot = MolFromPDBFile(
    _, self.prot = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_protein_fixer_rdkit.pdb'),
        sanitize=True,
        removeHs=False)
        add_hydrogens=False, calc_charges=False, sanitize=True)

    self.lig = SDMolSupplier(
    _, self.lig = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_ligand.sdf'),
        sanitize=True,
        removeHs=False)[0]
        add_hydrogens=False, calc_charges=False, sanitize=True)

  def test_compute_ring_center(self):
    # FIXME might break with different version of rdkit
@@ -334,7 +332,7 @@ class TestFeaturizationFunctions(unittest.TestCase):
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')

  def test_compute_all_ecfp(self):
    mol = MolFromMolFile(self.ligand_file)
    _, mol = rgf.load_molecule(self.ligand_file)
    num_atoms = mol.GetNumAtoms()
    for degree in range(1, 4):
      # TODO test if dict contains smiles
@@ -482,7 +480,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
        ],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        flatten=True)
        flatten=True,
        sanitize=True)
    self.assertIsInstance(featurizer, rgf.RdkitGridFeaturizer)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
@@ -510,7 +509,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
        box_width=box_w,
        ecfp_power=f_power,
        feature_types=['all_combined'],
        flatten=True)
        flatten=True,
        sanitize=True)

    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
+7 −2
Original line number Diff line number Diff line
@@ -87,7 +87,8 @@ def compute_charges(mol):
  return mol


def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):
def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True,
                  sanitize=False):
  """
  Converts molecule file to (xyz-coords, obmol object)

@@ -99,7 +100,7 @@ def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):
  :return: (xyz, mol)
  """
  if ".mol2" in molecule_file:
    my_mol = Chem.MolFromMol2File(molecule_file)
    my_mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
  elif ".sdf" in molecule_file:
    suppl = Chem.SDMolSupplier(str(molecule_file), sanitize=False)
    my_mol = suppl[0]
@@ -118,6 +119,10 @@ def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):

  if add_hydrogens or calc_charges:
    my_mol = add_hydrogens_to_mol(my_mol)
  # TODO: mol should be always sanitized when charges are calculated
  # can't change it now because it would break a lot of examples
  if sanitize:
    Chem.SanitizeMol(my_mol)
  if calc_charges:
    compute_charges(my_mol)