Unverified Commit 50992082 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2328 from mufeili/master

Reshape Labels for Single-Task Datasets for torch.nn.CrossEntropyLoss
parents 0b3b790d 05ae3a79
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -187,7 +187,12 @@ class SparseSoftmaxCrossEntropy(Loss):

  def _compute_tf_loss(self, output, labels):
    """Return the sparse softmax cross entropy between logits and labels.

    Parameters
    ----------
    output: Tensor
        Logits produced by the model.
    labels: Tensor
        Integer class labels. A trailing singleton class dimension
        (making labels the same rank as ``output``) is squeezed away,
        as required by the sparse cross-entropy op.
    """
    import tensorflow as tf

    # When labels carry a redundant last axis of size 1 (same rank as the
    # logits), drop it so ranks line up for the sparse loss.
    if len(labels.shape) == len(output.shape):
      labels = tf.squeeze(labels, axis=-1)
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        tf.cast(labels, tf.int32), output)

  def _create_pytorch_loss(self):
@@ -200,6 +205,9 @@ class SparseSoftmaxCrossEntropy(Loss):
      # This is for API consistency
      if len(output.shape) == 3:
        output = output.permute(0, 2, 1)

      if len(labels.shape) == len(output.shape):
        labels = labels.squeeze(-1)
      return ce_loss(output, labels.long())

    return loss
+27 −0
Original line number Diff line number Diff line
@@ -34,6 +34,19 @@ def test_attentivefp_regression():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_delaney

  tasks, all_dataset, transformers = load_delaney(featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = AttentiveFPModel(
      mode='regression',
      n_tasks=len(tasks),
      num_layers=1,
      num_timesteps=1,
      graph_feat_size=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
@@ -56,6 +69,20 @@ def test_attentivefp_classification():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_bace_classification

  tasks, all_dataset, transformers = load_bace_classification(
      featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = AttentiveFPModel(
      mode='classification',
      n_tasks=len(tasks),
      num_layers=1,
      num_timesteps=1,
      graph_feat_size=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
+29 −0
Original line number Diff line number Diff line
@@ -39,6 +39,20 @@ def test_gat_regression():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_delaney

  tasks, all_dataset, transformers = load_delaney(featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GATModel(
      mode='regression',
      n_tasks=len(tasks),
      graph_attention_layers=[2],
      n_attention_heads=1,
      residual=False,
      predictor_hidden_feats=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
@@ -62,6 +76,21 @@ def test_gat_classification():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_bace_classification

  tasks, all_dataset, transformers = load_bace_classification(
      featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GATModel(
      mode='classification',
      n_tasks=len(tasks),
      graph_attention_layers=[2],
      n_attention_heads=1,
      residual=False,
      predictor_hidden_feats=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
+26 −0
Original line number Diff line number Diff line
@@ -39,6 +39,18 @@ def test_gcn_regression():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_delaney

  tasks, all_dataset, transformers = load_delaney(featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GCNModel(
      n_tasks=len(tasks),
      graph_conv_layers=[2],
      residual=False,
      predictor_hidden_feats=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
@@ -62,6 +74,20 @@ def test_gcn_classification():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_bace_classification

  tasks, all_dataset, transformers = load_bace_classification(
      featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = dc.models.GCNModel(
      mode='classification',
      n_tasks=len(tasks),
      graph_conv_layers=[2],
      residual=False,
      predictor_hidden_feats=2)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
+12 −0
Original line number Diff line number Diff line
@@ -186,6 +186,12 @@ class TestLosses(unittest.TestCase):
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

    labels = tf.constant([[1], [0]])
    result = loss._compute_tf_loss(outputs, labels).numpy()
    softmax = np.exp(y) / np.expand_dims(np.sum(np.exp(y), axis=1), 1)
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_sparse_softmax_cross_entropy_pytorch(self):
    """Test SparseSoftmaxCrossEntropy."""
@@ -198,6 +204,12 @@ class TestLosses(unittest.TestCase):
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

    labels = torch.tensor([[1], [0]])
    result = loss._create_pytorch_loss()(outputs, labels).numpy()
    softmax = np.exp(y) / np.expand_dims(np.sum(np.exp(y), axis=1), 1)
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_VAE_ELBO_tf(self):
    """."""
Loading