Commit 00d80849 authored by marta-sd's avatar marta-sd
Browse files

added sanitize flag

parent 22798d00
Loading
Loading
Loading
Loading
+36 −16
Original line number Original line Diff line number Diff line
@@ -910,6 +910,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
               voxel_width=1.0,
               voxel_width=1.0,
               flatten=False,
               flatten=False,
               verbose=True,
               verbose=True,
               sanitize=False,
               **kwargs):
               **kwargs):
    """Parameters:
    """Parameters:
    -----------
    -----------
@@ -967,6 +968,10 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        'box_x', 'box_y', 'box_z', 'save_intermediates', 'voxelize_features',
        'box_x', 'box_y', 'box_z', 'save_intermediates', 'voxelize_features',
        'parallel', 'voxel_feature_types'
        'parallel', 'voxel_feature_types'
    ]
    ]

    # list of features that require sanitized molecules
    require_sanitized = ['pi_stack', 'cation_pi']

    for arg in deprecated_args:
    for arg in deprecated_args:
      if arg in kwargs and verbose:
      if arg in kwargs and verbose:
        warn('%s argument was removed and it is ignored,'
        warn('%s argument was removed and it is ignored,'
@@ -974,6 +979,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
             DeprecationWarning)
             DeprecationWarning)


    self.verbose = verbose
    self.verbose = verbose
    self.sanitize = sanitize
    self.flatten = flatten
    self.flatten = flatten


    self.ecfp_degree = ecfp_degree
    self.ecfp_degree = ecfp_degree
@@ -1222,8 +1228,20 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    # each entry is a tuple (is_flat, feature_name)
    # each entry is a tuple (is_flat, feature_name)
    self.feature_types = []
    self.feature_types = []


    # list of features that cannot be calculated with specified parameters
    # this list is used to define <flat/voxel/all>_combined subset
    ignored_features = []
    if self.sanitize is False:
      ignored_features += require_sanitized

    # parse provided feature types
    # parse provided feature types
    for feature_type in feature_types:
    for feature_type in feature_types:
      if self.sanitize is False and feature_type in require_sanitized:
        if self.verbose:
          warn('sanitize is set to False, %s feature will be ignored' %
               feature_type)
        continue

      if feature_type in self.FLAT_FEATURES:
      if feature_type in self.FLAT_FEATURES:
        self.feature_types.append((True, feature_type))
        self.feature_types.append((True, feature_type))
        if self.flatten is False:
        if self.flatten is False:
@@ -1235,28 +1253,28 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        self.feature_types.append((False, feature_type))
        self.feature_types.append((False, feature_type))


      elif feature_type == 'flat_combined':
      elif feature_type == 'flat_combined':
        self.feature_types += list(
        self.feature_types += [(True, ftype)
            zip([True] * len(self.FLAT_FEATURES),
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                sorted(self.FLAT_FEATURES.keys())))
                               if ftype not in ignored_features]
        if self.flatten is False:
        if self.flatten is False:
          if self.verbose:
          if self.verbose:
            warn('flat features are used, output will be flattened')
            warn('Flat features are used, output will be flattened')
          self.flatten = True
          self.flatten = True


      elif feature_type == 'voxel_combined':
      elif feature_type == 'voxel_combined':
        self.feature_types += list(
        self.feature_types += [(False, ftype)
            zip([False] * len(self.VOXEL_FEATURES),
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                sorted(self.VOXEL_FEATURES.keys())))
                               if ftype not in ignored_features]
      elif feature_type == 'all_combined':
      elif feature_type == 'all_combined':
        self.feature_types += list(
        self.feature_types += [(True, ftype)
            zip([True] * len(self.FLAT_FEATURES),
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                sorted(self.FLAT_FEATURES.keys())))
                               if ftype not in ignored_features]
        self.feature_types += list(
        self.feature_types += [(False, ftype)
            zip([False] * len(self.VOXEL_FEATURES),
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                sorted(self.VOXEL_FEATURES.keys())))
                               if ftype not in ignored_features]
        if self.flatten is False:
        if self.flatten is False:
          if self.verbose:
          if self.verbose:
            warn('flat feature are used, output will be flattened')
            warn('Flat feature are used, output will be flattened')
          self.flatten = True
          self.flatten = True
      elif self.verbose:
      elif self.verbose:
        warn('Ignoring unknown feature %s' % feature_type)
        warn('Ignoring unknown feature %s' % feature_type)
@@ -1333,7 +1351,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING
    ############################################################## TIMING


    if not self.ligand_only:
    if not self.ligand_only:
      protein_xyz, protein_rdk = load_molecule(protein_pdb, calc_charges=True)
      protein_xyz, protein_rdk = load_molecule(
          protein_pdb, calc_charges=True, sanitize=self.sanitize)
    ############################################################## TIMING
    ############################################################## TIMING
    time2 = time.time()
    time2 = time.time()
    log("TIMING: Loading protein coordinates took %0.3f s" % (time2 - time1),
    log("TIMING: Loading protein coordinates took %0.3f s" % (time2 - time1),
@@ -1342,7 +1361,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING
    ############################################################## TIMING
    time1 = time.time()
    time1 = time.time()
    ############################################################## TIMING
    ############################################################## TIMING
    ligand_xyz, ligand_rdk = load_molecule(ligand_file, calc_charges=True)
    ligand_xyz, ligand_rdk = load_molecule(
        ligand_file, calc_charges=True, sanitize=self.sanitize)
    ############################################################## TIMING
    ############################################################## TIMING
    time2 = time.time()
    time2 = time.time()
    log("TIMING: Loading ligand coordinates took %0.3f s" % (time2 - time1),
    log("TIMING: Loading ligand coordinates took %0.3f s" % (time2 - time1),
+10 −10
Original line number Original line Diff line number Diff line
@@ -8,7 +8,7 @@ import unittest
import numpy as np
import numpy as np
np.random.seed(123)
np.random.seed(123)


from rdkit.Chem import MolFromMolFile, MolFromSmiles, MolFromPDBFile, SDMolSupplier
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.AllChem import Mol, ComputeGasteigerCharges
from rdkit.Chem.AllChem import Mol, ComputeGasteigerCharges


from deepchem.feat import rdkit_grid_featurizer as rgf
from deepchem.feat import rdkit_grid_featurizer as rgf
@@ -205,15 +205,13 @@ class TestPiInteractions(unittest.TestCase):
    self.cycle4.Compute2DCoords()
    self.cycle4.Compute2DCoords()


    # load and sanitize two real molecules
    # load and sanitize two real molecules
    self.prot = MolFromPDBFile(
    _, self.prot = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_protein_fixer_rdkit.pdb'),
        os.path.join(current_dir, '3ws9_protein_fixer_rdkit.pdb'),
        sanitize=True,
        add_hydrogens=False, calc_charges=False, sanitize=True)
        removeHs=False)


    self.lig = SDMolSupplier(
    _, self.lig = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_ligand.sdf'),
        os.path.join(current_dir, '3ws9_ligand.sdf'),
        sanitize=True,
        add_hydrogens=False, calc_charges=False, sanitize=True)
        removeHs=False)[0]


  def test_compute_ring_center(self):
  def test_compute_ring_center(self):
    # FIXME might break with different version of rdkit
    # FIXME might break with different version of rdkit
@@ -334,7 +332,7 @@ class TestFeaturizationFunctions(unittest.TestCase):
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')


  def test_compute_all_ecfp(self):
  def test_compute_all_ecfp(self):
    mol = MolFromMolFile(self.ligand_file)
    _, mol = rgf.load_molecule(self.ligand_file)
    num_atoms = mol.GetNumAtoms()
    num_atoms = mol.GetNumAtoms()
    for degree in range(1, 4):
    for degree in range(1, 4):
      # TODO test if dict contains smiles
      # TODO test if dict contains smiles
@@ -482,7 +480,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
        ],
        ],
        ecfp_power=ecfp_power,
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        splif_power=splif_power,
        flatten=True)
        flatten=True,
        sanitize=True)
    self.assertIsInstance(featurizer, rgf.RdkitGridFeaturizer)
    self.assertIsInstance(featurizer, rgf.RdkitGridFeaturizer)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
                                                    [self.protein_file])
@@ -510,7 +509,8 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
        box_width=box_w,
        box_width=box_w,
        ecfp_power=f_power,
        ecfp_power=f_power,
        feature_types=['all_combined'],
        feature_types=['all_combined'],
        flatten=True)
        flatten=True,
        sanitize=True)


    prot_tensor = rgf_featurizer._voxelize(
    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
        rgf.convert_atom_to_voxel,
+7 −2
Original line number Original line Diff line number Diff line
@@ -87,7 +87,8 @@ def compute_charges(mol):
  return mol
  return mol




def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):
def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True,
                  sanitize=False):
  """
  """
  Converts molecule file to (xyz-coords, obmol object)
  Converts molecule file to (xyz-coords, obmol object)


@@ -99,7 +100,7 @@ def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):
  :return: (xyz, mol)
  :return: (xyz, mol)
  """
  """
  if ".mol2" in molecule_file:
  if ".mol2" in molecule_file:
    my_mol = Chem.MolFromMol2File(molecule_file)
    my_mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
  elif ".sdf" in molecule_file:
  elif ".sdf" in molecule_file:
    suppl = Chem.SDMolSupplier(str(molecule_file), sanitize=False)
    suppl = Chem.SDMolSupplier(str(molecule_file), sanitize=False)
    my_mol = suppl[0]
    my_mol = suppl[0]
@@ -118,6 +119,10 @@ def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):


  if add_hydrogens or calc_charges:
  if add_hydrogens or calc_charges:
    my_mol = add_hydrogens_to_mol(my_mol)
    my_mol = add_hydrogens_to_mol(my_mol)
  # TODO: mol should be always sanitized when charges are calculated
  # can't change it now because it would break a lot of examples
  if sanitize:
    Chem.SanitizeMol(my_mol)
  if calc_charges:
  if calc_charges:
    compute_charges(my_mol)
    compute_charges(my_mol)