Unverified Commit 6860317d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2340 from MariBerry/master

Enable passing empty list of tasks for prospective data prediction by classification GraphConvModel 
parents 144ae446 291b5c2f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -980,7 +980,7 @@ class GraphConvModel(KerasModel):
          batch_size=self.batch_size,
          deterministic=deterministic,
          pad_batches=pad_batches):
        if self.mode == 'classification':
        if y_b is not None and self.mode == 'classification':
          y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape(
              -1, self.n_tasks, self.n_classes)
        multiConvMol = ConvMol.agglomerate_mols(X_b)
+19 −0
Original line number Diff line number Diff line
@@ -113,6 +113,25 @@ def test_graph_conv_regression_uncertainty():
  assert mean_std < mean_value


def test_graph_conv_model_no_task():
  tasks, dataset, _, __ = get_dataset('classification', 'GraphConv')
  batch_size = 10
  model = GraphConvModel(
      len(tasks),
      batch_size=batch_size,
      batch_normalize=False,
      mode='classification')
  model.fit(dataset, nb_epoch=20)
  # predict datset with no y (ensured by tasks = [])
  bace_url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv"
  dc.utils.data_utils.download_url(url=bace_url, name="bace_tmp.csv")
  loader = dc.data.CSVLoader(
      tasks=[], smiles_field='mol', featurizer=dc.feat.ConvMolFeaturizer())
  td = loader.featurize(
      os.path.join(dc.utils.data_utils.get_data_dir(), "bace_tmp.csv"))
  model.predict(td)


def test_graph_conv_atom_features():
  tasks, dataset, transformers, metric = get_dataset(
      'regression', 'Raw', num_tasks=1)