Commit 4ed84466 authored by evanfeinberg's avatar evanfeinberg
Browse files

added nnscore featurization and nosetest

parent 9ce844d5
Loading
Loading
Loading
Loading
+93 −0
Original line number Diff line number Diff line
"""
Contains methods for generating a pdbbind dataset mapping
  complexes (protein + ligand) to experimental binding measurement.
"""
import pickle
import os
import pandas as pd
from rdkit import Chem
from glob import glob
import re


def extract_labels(pdbbind_label_file):
  """Extract binding-affinity labels from a pdbbind index file.

  Parameters
  ----------
  pdbbind_label_file: str
    Path to a pdbbind index file. Lines starting with "#" are comments.
    Data lines have the format:
      PDB-code Resolution Release-Year -logKd Kd reference ligand-name

  Returns
  -------
  dict mapping each PDB code (str) to its -logKd value (str, 4th field).
  """
  assert os.path.isfile(pdbbind_label_file)
  labels = {}
  with open(pdbbind_label_file) as f:
    for line in f:
      # Skip comments and blank lines. (The original split a blank line
      # into [] and then indexed [0], raising IndexError.)
      if line.startswith("#") or not line.strip():
        continue
      fields = line.split()
      labels[fields[0]] = fields[3]
  return labels

def construct_df(pdb_stem_directory, pdbbind_label_file, pdbbind_df_pkl):
  """Build and pickle a pandas DataFrame of pdbbind complexes.

  Parameters
  ----------
  pdb_stem_directory: str
    Stem directory containing one subdirectory per pdb complex, each
    holding *_ligand_hyd.pdb, *_protein_hyd.pdb, and *_ligand.mol2 files.
  pdbbind_label_file: str
    pdbbind index file with binding assay data for the co-crystallized
    ligand of each complex (see extract_labels).
  pdbbind_df_pkl: str
    Output path; the resulting DataFrame is pickled here.

  Each DataFrame row contains: pdb_id, smiles string, unique complex id,
  protein pdb file lines, ligand pdb file lines, ligand mol2 file lines,
  and the experimental "label". Complexes whose ligand pdb cannot be
  parsed by rdkit are silently skipped.

  Side effect: changes the process working directory to
  pdb_stem_directory (kept for backward compatibility with callers).
  """
  labels = extract_labels(pdbbind_label_file)
  df_rows = []
  os.chdir(pdb_stem_directory)
  pdb_directories = [pdb.replace('/', '') for pdb in glob('*/')]

  for pdb_dir in pdb_directories:
    print("About to extract ligand and protein input files")
    pdb_id = os.path.basename(pdb_dir)
    # Reset all three every iteration. The original never initialized
    # ligand_mol2, so a directory missing its mol2 file either raised
    # NameError (first iteration) or silently reused the previous
    # directory's filename, defeating the required-files check below.
    ligand_pdb = None
    protein_pdb = None
    ligand_mol2 = None
    for f in os.listdir(pdb_dir):
      if re.search("_ligand_hyd.pdb$", f):
        ligand_pdb = f
      elif re.search("_protein_hyd.pdb$", f):
        protein_pdb = f
      elif re.search("_ligand.mol2$", f):
        ligand_mol2 = f

    print("Extracted Input Files:")
    print((ligand_pdb, protein_pdb, ligand_mol2))
    if not ligand_pdb or not protein_pdb or not ligand_mol2:
      raise ValueError("Required files not present for %s" % pdb_dir)
    ligand_pdb_path = os.path.join(pdb_dir, ligand_pdb)
    protein_pdb_path = os.path.join(pdb_dir, protein_pdb)
    ligand_mol2_path = os.path.join(pdb_dir, ligand_mol2)

    with open(protein_pdb_path, "rb") as f:
      protein_pdb_lines = f.readlines()

    with open(ligand_pdb_path, "rb") as f:
      ligand_pdb_lines = f.readlines()

    try:
      with open(ligand_mol2_path, "rb") as f:
        ligand_mol2_lines = f.readlines()
    except IOError:
      # Best-effort: mol2 data is optional downstream, fall back to
      # empty. (Narrowed from a bare except, which also swallowed
      # unrelated errors such as KeyboardInterrupt.)
      ligand_mol2_lines = []

    print("About to compute ligand smiles string.")
    ligand_mol = Chem.MolFromPDBFile(ligand_pdb_path)
    if ligand_mol is None:
      # rdkit could not parse this ligand pdb; skip the complex.
      continue
    smiles = Chem.MolToSmiles(ligand_mol)
    complex_id = "%s%s" % (pdb_id, smiles)
    label = labels[pdb_id]
    df_rows.append([pdb_id, smiles, complex_id, protein_pdb_lines,
                    ligand_pdb_lines, ligand_mol2_lines, label])

  pdbbind_df = pd.DataFrame(df_rows, columns=('pdb_id', 'smiles', 'complex_id',
                                              'protein_pdb', 'ligand_pdb',
                                              'ligand_mol2', 'label'))

  with open(pdbbind_df_pkl, "wb") as f:
    pickle.dump(pdbbind_df, f)
+4 −1
Original line number Diff line number Diff line
@@ -160,7 +160,10 @@ class Model(object):
      y_preds = []
      for j in range(len(interval_points)-1):
        indices = range(interval_points[j], interval_points[j+1])
        y_preds.append(self.predict_on_batch(X[indices, :]))
        y_pred_on_batch = self.predict_on_batch(X[indices, :])
        y_pred_on_batch = np.reshape(y_pred_on_batch, (len(indices),))
        y_preds.append(y_pred_on_batch)

      y_pred = np.concatenate(y_preds)
      y_pred = np.reshape(y_pred, np.shape(y))

+4 −0
Original line number Diff line number Diff line
@@ -134,6 +134,8 @@ class MultiTaskDNN(KerasModel):
    """
    data = self.get_data_dict(X)
    y_pred_dict = self.raw_model.predict_on_batch(data)
    print("y_pred_dict.keys()")
    print(y_pred_dict.keys())
    sorted_tasks = sorted(self.task_types.keys())
    nb_samples = np.shape(X)[0]
    nb_tasks = len(sorted_tasks)
@@ -141,6 +143,8 @@ class MultiTaskDNN(KerasModel):
    for ind, task in enumerate(sorted_tasks):
      task_type = self.task_types[task]
      taskname = "task%d" % ind
      print("taskname")
      print(taskname)
      if task_type == "classification":
        # Class probabilities are predicted for classification outputs. Instead,
        # output the most likely class.
+23 −5
Original line number Diff line number Diff line
@@ -44,6 +44,15 @@ def add_featurize_group(featurize_cmd):
  featurize_group.add_argument(
      "--threshold", type=float, default=None,
      help="If specified, will be used to binarize real-valued target-fields.")
  featurize_group.add_argument(
      "--protein-pdb-field", type=str, default=None,
      help="Name of field holding protein pdb.")
  featurize_group.add_argument(
      "--ligand-pdb-field", type=str, default=None,
      help="Name of field holding ligand pdb.")
  featurize_group.add_argument(
      "--ligand-mol2-field", type=str, default=None,
      help="Name of field holding ligand mol2.")
  featurize_group.add_argument(
      "--parallel", type=float, default=None,
      help="Use multiprocessing will be used to parallelize featurization.")
@@ -62,7 +71,7 @@ def add_transforms_group(cmd):
           "to mean no transforms are required.")
  transform_group.add_argument(
      "--feature-types", nargs="+", required=1,
      choices=["user-specified-features", "ECFP", "RDKIT-descriptors"],
      choices=["user-specified-features", "ECFP", "RDKIT-descriptors", "NNScore"],
      help="Featurizations of data to use.\n"
           "'features' denotes user-defined features.\n"
           "'fingerprints' denotes ECFP fingeprints.\n"
@@ -205,7 +214,8 @@ def create_model(args):
    featurize_inputs(
        feature_dir, data_dir, args.input_files, args.user_specified_features,
        args.tasks, args.smiles_field, args.split_field, args.id_field,
        args.threshold, args.parallel)
        args.threshold, args.protein_pdb_field,
        args.ligand_pdb_field, args.ligand_mol2_field, args.parallel)

  if args.generate_dataset:
    print("+++++++++++++++++++++++++++++++++")
@@ -275,7 +285,8 @@ def parse_args(input_args=None):

def featurize_inputs(feature_dir, data_dir, input_files,
                     user_specified_features, tasks, smiles_field,
                     split_field, id_field, threshold, parallel):
                     split_field, id_field, threshold, protein_pdb_field, 
                     ligand_pdb_field, ligand_mol2_field, parallel):

  """Allows for parallel data featurization."""
  featurize_input_partial = partial(featurize_input,
@@ -285,7 +296,10 @@ def featurize_inputs(feature_dir, data_dir, input_files,
                                    smiles_field=smiles_field,
                                    split_field=split_field,
                                    id_field=id_field,
                                    threshold=threshold)
                                    threshold=threshold,
                                    protein_pdb_field=protein_pdb_field,
                                    ligand_pdb_field=ligand_pdb_field,
                                    ligand_mol2_field=ligand_mol2_field)

  if parallel:
    pool = mp.Pool(int(mp.cpu_count()/2))
@@ -302,13 +316,17 @@ def featurize_inputs(feature_dir, data_dir, input_files,
  FeaturizedSamples(samples_dir, dataset_files)

def featurize_input(input_file, feature_dir, user_specified_features, tasks,
                    smiles_field, split_field, id_field, threshold):
                    smiles_field, split_field, id_field, threshold, protein_pdb_field,
                     ligand_pdb_field, ligand_mol2_field):
  """Featurizes raw input data."""
  featurizer = DataFeaturizer(tasks=tasks,
                              smiles_field=smiles_field,
                              split_field=split_field,
                              id_field=id_field,
                              threshold=threshold,
                              protein_pdb_field=protein_pdb_field,
                              ligand_pdb_field=ligand_pdb_field,
                              ligand_mol2_field=ligand_mol2_field,
                              user_specified_features=user_specified_features,
                              verbose=True)
  out = os.path.join(
+12 −0
Original line number Diff line number Diff line
@@ -294,6 +294,10 @@ def write_dataset_single(val, data_dir, feature_types):
def _df_to_numpy(df, feature_types):
  """Transforms a featurized dataset df into standard set of numpy arrays"""
  if not set(feature_types).issubset(df.keys()):
    print("feature_types")
    print(feature_types)
    print("df.keys()")
    print(df.keys())
    raise ValueError(
        "Featurized data does not support requested feature_types.")
  # perform common train/test split across all tasks
@@ -315,9 +319,17 @@ def _df_to_numpy(df, feature_types):

  # Set missing data to have weight zero
  missing = (y == "")
  print("missing")
  print(missing)
  y[missing] = 0.
  w[missing] = 0.

  print("len(sorted_ids)")
  print(len(sorted_ids))
  print("np.shape(x) np.shape(y) np.shape(w)")
  print(np.shape(x))
  print(np.shape(y))
  print(np.shape(w))
  return sorted_ids, x.astype(float), y.astype(float), w.astype(float)


Loading