Commit 1c3a0ad4 authored by mufeili's avatar mufeili
Browse files

Update

parent 94239068
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -200,6 +200,8 @@ class SparseSoftmaxCrossEntropy(Loss):
      # This is for API consistency
      if len(output.shape) == 3:
        output = output.permute(0, 2, 1)
      if labels.shape[-1] == 1:
        labels = labels.squeeze(-1)
      return ce_loss(output, labels.long())

    return loss
+6 −0
Original line number Diff line number Diff line
@@ -198,6 +198,12 @@ class TestLosses(unittest.TestCase):
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

    labels = torch.tensor([[1, 0]])
    result = loss._create_pytorch_loss()(outputs, labels).numpy()
    softmax = np.exp(y) / np.expand_dims(np.sum(np.exp(y), axis=1), 1)
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_VAE_ELBO_tf(self):
    """."""
+0 −5
Original line number Diff line number Diff line
@@ -306,9 +306,4 @@ class AttentiveFPModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(
        ([], labels, weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+0 −5
Original line number Diff line number Diff line
@@ -367,9 +367,4 @@ class GATModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(GATModel, self)._prepare_batch(([], labels,
                                                               weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+0 −5
Original line number Diff line number Diff line
@@ -351,9 +351,4 @@ class GCNModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(GCNModel, self)._prepare_batch(([], labels,
                                                               weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
Loading