Commit 9bed010b authored by miaecle's avatar miaecle
Browse files

Merge remote-tracking branch 'remotes/origin/master' into momlnet

parents a85d57f1 d77ccf9b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ from deepchem.feat.basic import RDKitDescriptors
from deepchem.feat.coulomb_matrices import CoulombMatrix
from deepchem.feat.coulomb_matrices import CoulombMatrixEig
from deepchem.feat.grid_featurizer import GridFeaturizer
from deepchem.feat.rdkit_grid_featurizer import RdkitGridFeaturizer
from deepchem.feat.nnscore_utils import hydrogenate_and_compute_partial_charges
from deepchem.feat.binding_pocket_features import BindingPocketFeaturizer
from deepchem.feat.one_hot import OneHotFeaturizer
+1213 −0

File added.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
# coding=utf-8
"""
Contains an abstract base class that supports data transformations.
"""
+125 −0
Original line number Diff line number Diff line
import logging

import os
import numpy as np
import tempfile
import shutil
from rdkit import Chem
from rdkit.Chem import AllChem
from pdbfixer import PDBFixer
from simtk.openmm.app import PDBFile

try:
  from StringIO import StringIO
except ImportError:
  from io import StringIO


class MoleculeLoadException(Exception):

  def __init__(self, *args, **kwargs):
    Exception.__init__(*args, **kwargs)


def get_xyz_from_mol(mol):
  """
  returns an m x 3 np array of 3d coords
  of given rdkit molecule
  """
  xyz = np.zeros((mol.GetNumAtoms(), 3))
  conf = mol.GetConformer()
  for i in range(conf.GetNumAtoms()):
    position = conf.GetAtomPosition(i)
    xyz[i, 0] = position.x
    xyz[i, 1] = position.y
    xyz[i, 2] = position.z
  return (xyz)


def add_hydrogens_to_mol(mol):
  molecule_file = None
  try:
    pdbblock = Chem.MolToPDBBlock(mol)
    pdb_stringio = StringIO()
    pdb_stringio.write(pdbblock)
    pdb_stringio.seek(0)
    fixer = PDBFixer(pdbfile=pdb_stringio)
    fixer.addMissingHydrogens(7.4)

    hydrogenated_io = StringIO()
    PDBFile.writeFile(fixer.topology, fixer.positions, hydrogenated_io)
    hydrogenated_io.seek(0)
    return Chem.MolFromPDBBlock(
        hydrogenated_io.read(), sanitize=False, removeHs=False)
  except ValueError as e:
    logging.warning("Unable to add hydrogens", e)
    raise MoleculeLoadException(e)
  finally:
    try:
      os.remove(molecule_file)
    except (OSError, TypeError):
      pass


def compute_charges(mol):
  try:
    AllChem.ComputeGasteigerCharges(mol)
  except Exception as e:
    logging.exception("Unable to compute charges for mol")
    raise MoleculeLoadException(e)
  return mol


def load_molecule(molecule_file, add_hydrogens=True, calc_charges=True):
  """Converts molecule file to (xyz-coords, obmol object)

  Given molecule_file, returns a tuple of xyz coords of molecule
  and an rdkit object representing that molecule
  """
  if ".mol2" in molecule_file or ".sdf" in molecule_file:
    suppl = Chem.SDMolSupplier(str(molecule_file), sanitize=False)
    my_mol = suppl[0]
  elif ".pdbqt" in molecule_file:
    raise MoleculeLoadException("Don't support pdbqt files yet")
  elif ".pdb" in molecule_file:
    my_mol = Chem.MolFromPDBFile(
        str(molecule_file), sanitize=False, removeHs=False)
  else:
    raise ValueError("Unrecognized file type")

  if my_mol is None:
    raise ValueError("Unable to read non None Molecule Object")

  if add_hydrogens:
    my_mol = add_hydrogens_to_mol(my_mol)
  if calc_charges:
    compute_charges(my_mol)

  xyz = get_xyz_from_mol(my_mol)

  return xyz, my_mol


def write_molecule(mol, outfile):
  if ".pdbqt" in outfile:
    # TODO (LESWING) create writer for pdbqt which includes charges
    writer = Chem.PDBWriter(outfile)
    writer.write(mol)
    writer.close()
    pass
  elif ".pdb" in outfile:
    writer = Chem.PDBWriter(outfile)
    writer.write(mol)
    writer.close()
  else:
    raise ValueError("Unsupported Format")


def pdbqt_to_pdb(filename):
  base_filename = os.path.splitext(filename)[0]
  pdb_filename = base_filename + ".pdb"
  pdbqt_data = open(filename).readlines()
  with open(pdb_filename, 'w') as fout:
    for line in pdbqt_data:
      fout.write("%s\n" % line[:66])
  return pdb_filename
+41 −26
Original line number Diff line number Diff line
@@ -8,7 +8,9 @@ BENCHMARK_TO_DESIRED_KEY_MAP = {
    "logreg": "logistic regression",
    "tf": "Multitask network",
    "tf_robust": "robust MT-NN",
    "tf_regression": "NN regression",
    "graphconv": "graph convolution",
    "graphconvreg": "graphconv regression",
}
DESIRED_RESULTS_CSV = "devtools/jenkins/desired_results.csv"
TEST_RESULTS_CSV = "examples/results.csv"
@@ -53,12 +55,18 @@ def find_desired_result(result, desired_results):


def is_good_result(my_result, desired_result):
  retval = True
  message = []
  for key in ['train_score', 'test_score']:
    # Higher is Better
    desired_value = desired_result[key] * (1.0 - CUSHION_PERCENT)
    if my_result[key] < desired_value:
      return False
  return True
      message_part = "%s,%s,%s,%s,%s,%s" % (
          my_result['data_set'], my_result['model'], my_result['split'], key,
          my_result[key], desired_result[key])
      message.append(message_part)
      retval = False
  return retval, message


def test_compare_results():
@@ -66,15 +74,22 @@ def test_compare_results():
  desired_results = parse_desired_results(desired_results)
  test_results = open(TEST_RESULTS_CSV).readlines()
  test_results = parse_test_results(test_results)
  failures = []
  exceptions = []
  for test_result in test_results:
    try:
      desired_result = find_desired_result(test_result, desired_results)
    if not is_good_result(test_result, desired_result):
      exceptions.append(({"test_result": test_result}, {"desired_result": desired_result}))
  if len(exceptions) > 0:
      passes, message = is_good_result(test_result, desired_result)
      if not passes:
        failures.extend(message)
    except Exception as e:
      exceptions.append("Unable to find desired result for %s" % test_result)
  for exception in exceptions:
    print(exception)
    assert_true(len(exceptions) == 0, "Some performance benchmarks not passed")
  for failure in failures:
    print(failure)
  assert_true(len(exceptions) == 0, "Error parsing performance results")
  assert_true(len(failures) == 0, "Some performance benchmarks not passed")

  if __name__ == "__main__":
    test_compare_results()
Loading