Commit a94fa1d3 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Reformat test_gcn.py with yapf 0.22

parent b20651c4
Loading
Loading
Loading
Loading
+43 −37
Original line number Diff line number Diff line
@@ -22,12 +22,13 @@ except:
def test_gcn_regression():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset('regression',
                                                     featurizer=featurizer)
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = GCNModel(mode='regression',
  model = GCNModel(
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
@@ -43,7 +44,8 @@ def test_gcn_regression():

  tasks, all_dataset, transformers = load_delaney(featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GCNModel(n_tasks=len(tasks),
  model = dc.models.GCNModel(
      n_tasks=len(tasks),
      graph_conv_layers=[2],
      residual=False,
      predictor_hidden_feats=2)
@@ -55,12 +57,13 @@ def test_gcn_regression():
def test_gcn_classification():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset('classification',
                                                     featurizer=featurizer)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = GCNModel(mode='classification',
  model = GCNModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
@@ -77,7 +80,8 @@ def test_gcn_classification():
  tasks, all_dataset, transformers = load_bace_classification(
      featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GCNModel(mode='classification',
  model = dc.models.GCNModel(
      mode='classification',
      n_tasks=len(tasks),
      graph_conv_layers=[2],
      residual=False,
@@ -90,13 +94,14 @@ def test_gcn_classification():
def test_gcn_reload():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset('classification',
                                                     featurizer=featurizer)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model_dir = tempfile.mkdtemp()
  model = GCNModel(mode='classification',
  model = GCNModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      model_dir=model_dir,
@@ -107,7 +112,8 @@ def test_gcn_reload():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  reloaded_model = GCNModel(mode='classification',
  reloaded_model = GCNModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      model_dir=model_dir,