Commit 2b6e9b2a authored by mufeili's avatar mufeili
Browse files

Update

parent fcca3bde
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -306,4 +306,8 @@ class AttentiveFPModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(
        ([], labels, weights))

    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+4 −0
Original line number Diff line number Diff line
@@ -367,4 +367,8 @@ class GATModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(GATModel, self)._prepare_batch(([], labels,
                                                               weights))

    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+4 −0
Original line number Diff line number Diff line
@@ -351,4 +351,8 @@ class GCNModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(GCNModel, self)._prepare_batch(([], labels,
                                                               weights))

    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights
+4 −0
Original line number Diff line number Diff line
@@ -305,4 +305,8 @@ class MPNNModel(TorchModel):
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(MPNNModel, self)._prepare_batch(([], labels,
                                                                weights))

    if labels[0].shape[-1] == 1:
      labels = [lbl.squeeze(-1) for lbl in labels]

    return inputs, labels, weights