Commit 0a5f72f7 authored by mufeili's avatar mufeili
Browse files

Update

parent 2b6e9b2a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -307,6 +307,7 @@ class AttentiveFPModel(TorchModel):
    _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(
        ([], labels, weights))

    # torch.nn.CrossEntropyLoss expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

+1 −0
Original line number Diff line number Diff line
@@ -368,6 +368,7 @@ class GATModel(TorchModel):
    _, labels, weights = super(GATModel, self)._prepare_batch(([], labels,
                                                               weights))

    # torch.nn.CrossEntropyLoss expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

+2 −1
Original line number Diff line number Diff line
@@ -352,7 +352,8 @@ class GCNModel(TorchModel):
    _, labels, weights = super(GCNModel, self)._prepare_batch(([], labels,
                                                               weights))

    if labels[0].shape[-1] == 1:
    # torch.nn.CrossEntropyLoss expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+1 −0
Original line number Diff line number Diff line
@@ -306,6 +306,7 @@ class MPNNModel(TorchModel):
    _, labels, weights = super(MPNNModel, self)._prepare_batch(([], labels,
                                                                weights))

    # torch.nn.CrossEntropyLoss expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]