Commit 10357d74 authored by nd-02110114's avatar nd-02110114
Browse files

🔥 fix lint

parent 61d3dec7
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -88,7 +88,8 @@ class GAT(nn.Module):
    try:
      from torch_geometric.nn import GATConv, global_mean_pool
    except:
      raise ImportError("This class requires PyTorch Geometric to be installed.")
      raise ImportError(
          "This class requires PyTorch Geometric to be installed.")

    self.n_tasks = n_tasks
    self.mode = mode
@@ -246,7 +247,8 @@ class GATModel(TorchModel):
    try:
      from torch_geometric.data import Batch
    except:
      raise ImportError("This class requires PyTorch Geometric to be installed.")
      raise ImportError(
          "This class requires PyTorch Geometric to be installed.")

    inputs, labels, weights = batch
    pyg_graphs = [graph.to_pyg_graph() for graph in inputs[0]]