Commit f5169457 authored by Bharath Ramsundar
Browse files

Some more loading changes

parent 743a31cd
Loading
Loading
Loading
Loading
+30 −29
Original line number Diff line number Diff line
@@ -63,10 +63,6 @@ def load_molecules(paths, feature_types=["fingerprints"]):
  Returns a dictionary that maps smiles strings to dicts that contain
  fingerprints, smiles strings, scaffolds, mol_ids.

  TODO(rbharath): This function assumes that all datapoints are uniquely keyed
  by smiles strings. This doesn't hold true for the pdbbind dataset. Need to find
  a more general indexing mechanism.

  Parameters
  ----------
  paths: list
@@ -126,9 +122,7 @@ def get_target_names(paths, target_dir_name="targets"):
def load_assays(paths, target_dir_name="targets"):
  """Load regression dataset labels from assays.

  Returns a dictionary that maps smiles strings to label vectors.

  TODO(rbharath): Remove the use of smiles as unique identifier
  Returns a dictionary that maps mol_id's to label vectors.

  Parameters
  ----------
@@ -150,33 +144,33 @@ def load_assays(paths, target_dir_name="targets"):
        contents = pickle.load(f)
        if "prediction" not in contents:
          raise ValueError("Prediction Endpoint Missing.")
        for ind, smiles in enumerate(contents["smiles"]):
        for ind, id in enumerate(contents["mol_id"]):
          measurement = contents["prediction"][ind]
          if "split" is not None:
            splits[smiles] = contents["split"][ind]
            splits[id] = contents["split"][ind]
          else:
            splits[smiles] = None
          # TODO(rbharath): There is some amount of duplicate collisions
          # due to choice of smiles generation. Look into this more
          # carefully and see if the underlying issues are fundamental..
            splits[id] = None
          try:
            if measurement is None or np.isnan(measurement):
              continue
          except TypeError:
            continue
          if smiles not in labels:
            labels[smiles] = {}
          if id not in labels:
            labels[id] = {}
            # Ensure that each target has some entry in dict.
            for name in target_names:
              # Set all targets to invalid for now.
              labels[smiles][name] = -1
          labels[smiles][target_name] = measurement 
              labels[id][name] = -1
          labels[id][target_name] = measurement 
  print "load_assays()"
  print "labels"
  print labels
  return labels, splits

def load_datasets(paths, target_dir_name="targets", feature_types=["fingerprints"]):
  """Load both labels and fingerprints.

  Returns a dictionary that maps smiles to pairs of (fingerprint, labels)
  Returns a dictionary that maps mol_id's to pairs of (fingerprint, labels)
  where labels is itself a dict that maps target-names to labels.

  Parameters
@@ -186,15 +180,22 @@ def load_datasets(paths, target_dir_name="targets", feature_types=["fingerprints
  """
  data = {}
  molecules = load_molecules(paths, feature_types)
  print "load_datasets()"
  print "len(molecules)"
  print len(molecules)
  labels, splits = load_assays(paths, target_dir_name)
  for ind, smiles in enumerate(molecules):
    if smiles not in labels:
  print "len(labels)"
  print len(labels)
  for ind, id in enumerate(molecules):
    if id not in labels:
      continue
    mol = molecules[smiles]
    data[smiles] = {"fingerprint": mol["fingerprint"],
    mol = molecules[id]
    data[id] = {"fingerprint": mol["fingerprint"],
                "scaffold": mol["scaffold"],
                    "labels": labels[smiles],
                    "split": splits[smiles]}
                "labels": labels[id],
                "split": splits[id]}
  print "len(data)"
  print len(data)
  return data

def ensure_balanced(y, W):
@@ -229,10 +230,10 @@ def load_and_transform_dataset(paths, input_transforms, output_transforms,
      weight_positives=weight_positives)
  X = transform_inputs(X, input_transforms)
  trans_data = {}
  sorted_smiles = sorted(dataset.keys())
  sorted_ids = sorted(dataset.keys())
  sorted_targets = sorted(output_transforms.keys())
  for s_index, smiles in enumerate(sorted_smiles):
    datapoint = dataset[smiles]
  for s_index, id in enumerate(sorted_ids):
    datapoint = dataset[id]
    labels = {}
    for t_index, target in enumerate(sorted_targets):
      if W[s_index][t_index] == 0:
@@ -242,5 +243,5 @@ def load_and_transform_dataset(paths, input_transforms, output_transforms,
    datapoint["labels"] = labels
    datapoint["fingerprint"] = X[s_index]

    trans_data[smiles] = datapoint 
    trans_data[id] = datapoint 
  return trans_data