Commit 3e1337bc authored by mufeili's avatar mufeili
Browse files

Update

parent 0a5f72f7
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -308,7 +308,7 @@ class AttentiveFPModel(TorchModel):
        ([], labels, weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+1 −1
Original line number Diff line number Diff line
@@ -369,7 +369,7 @@ class GATModel(TorchModel):
                                                               weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+1 −1
Original line number Diff line number Diff line
@@ -307,7 +307,7 @@ class MPNNModel(TorchModel):
                                                                weights))

    # torch.nn.CrossEntropy expects the last dimension of labels to be non-singleton
    if labels[0].shape[-1] == 1:
    if labels[0].shape[-1] == 1 and self.model.mode == 'classification':
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights