Commit 15f8d3aa authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Minor changes for CANVAS loading.

parent bfb618e9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -93,8 +93,8 @@ def parse_args(input_args=None):
  train_test_cmd.set_defaults(func=train_test_input)

  # TRAIN FLAGS
  train_cmd = subparsers.add_parser("train",
                  help="Train a model on train data processed by transform.")
  train_cmd = subparsers.add_parser("fit",
                  help="Fit a model to training data.")
  group = train_cmd.add_argument_group("load-and-transform")
  group.add_argument("--task-type", default="classification",
                      choices=["classification", "regression"],
+12 −15
Original line number Diff line number Diff line
@@ -193,9 +193,9 @@ def multitask_to_singletask(dataset):
  sorted_targets = sorted(labels.keys())
  singletask = {target: {} for target in sorted_targets}
  # Populate the singletask datastructures
  sorted_smiles = sorted(dataset.keys())
  for index, smiles in enumerate(sorted_smiles):
    datapoint = dataset[smiles]
  sorted_ids = sorted(dataset.keys())
  for index, id in enumerate(sorted_ids):
    datapoint = dataset[id]
    labels = datapoint["labels"]
    for t_ind, target in enumerate(sorted_targets):
      if labels[target] == -1:
@@ -203,7 +203,7 @@ def multitask_to_singletask(dataset):
      else:
        datapoint_copy = datapoint.copy()
        datapoint_copy["labels"] = {target: labels[target]}
        singletask[target][smiles] = datapoint_copy 
        singletask[target][id] = datapoint_copy 
  return singletask

def split_dataset(dataset, splittype, seed=None):
@@ -221,16 +221,13 @@ def split_dataset(dataset, splittype, seed=None):
def train_test_specified_split(dataset):
  """Split provided data according to the split recorded on each datapoint.

  Args:
    dataset: dict mapping a datapoint identifier to a datapoint dict. Every
      datapoint must carry a "split" entry whose value is a string.

  Returns:
    A (train, test) tuple of dicts keyed by the same identifiers as dataset.
    Datapoints whose split is "train" (case-insensitive) go to train, those
    whose split is "test" go to test. Any other split value (e.g. "valid")
    is left out of both dicts.

  Raises:
    ValueError: if a datapoint has no "split" entry.
  """
  train, test = {}, {}
  # items() rather than iteritems() so the loop runs on Python 2 and 3 alike.
  for datapoint_id, datapoint in dataset.items():
    if "split" not in datapoint:
      raise ValueError("Missing required split information.")
    # Normalise once; the comparison below is case-insensitive.
    split = datapoint["split"].lower()
    # TODO(rbharath): Add support for validation sets.
    if split == "train":
      train[datapoint_id] = datapoint
    elif split == "test":
      test[datapoint_id] = datapoint
  return train, test

def train_test_random_split(dataset, frac_train=.8, seed=None):
@@ -302,13 +299,13 @@ def scaffold_separate(dataset):
    A dictionary of type produced by load_datasets. 
  """
  scaffolds = {}
  for smiles in dataset:
    datapoint = dataset[smiles]
  for id in dataset:
    datapoint = dataset[id]
    scaffold = datapoint["scaffold"]
    if scaffold not in scaffolds:
      scaffolds[scaffold] = [smiles]
      scaffolds[scaffold] = [id]
    else:
      scaffolds[scaffold].append(smiles)
      scaffolds[scaffold].append(id)
  # Sort from largest to smallest scaffold sets 
  return sorted(scaffolds.items(), key=lambda x: -len(x[1]))