Commit 9fcc7a49 authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix gat classification bug

parent e22deb33
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -192,7 +192,17 @@ class SparseSoftmaxCrossEntropy(Loss):

  def _create_pytorch_loss(self):
    import torch
    return torch.nn.CrossEntropyLoss(reduction='none')
    ce_loss = torch.nn.CrossEntropyLoss(reduction='mean')

    def loss(output, labels):
      # Convert (batch_size, tasks, classes) to (batch_size, classes, tasks)
      # CrossEntropyLoss only supports (batch_size, classes, tasks)
      # This is for API consistency
      if len(output.shape) == 3:
        output = output.permute(0, 2, 1)
      return ce_loss(output, labels.long())

    return loss


def _make_tf_shapes_consistent(output, labels):
@@ -251,3 +261,9 @@ def _ensure_float(output, labels):
  if labels.dtype not in (tf.float32, tf.float64):
    labels = tf.cast(labels, tf.float32)
  return (output, labels)


def _ensure_long(labels):
  """Make sure the outputs are Long types."""
  labels = [val.long() for val in labels]
  return labels
+9 −5
Original line number Diff line number Diff line
@@ -26,10 +26,10 @@ def test_gat_regression():
      mode='regression', n_tasks=n_tasks, batch_size=4, learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=300)
  scores = model.evaluate(dataset, [metric], transformers)
  # TODO: check whether this assertion is correct
  assert scores['mean_absolute_error'] < 1.0
  assert scores['mean_absolute_error'] < 0.2


@unittest.skipIf(not has_pytorch_and_pyg,
@@ -43,9 +43,13 @@ def test_gat_classification():
  # initialize models
  n_tasks = len(tasks)
  model = GATModel(
      mode='classification', n_tasks=n_tasks, batch_size=10, learning_rate=0.001)
      mode='classification',
      n_tasks=n_tasks,
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=10)
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=150)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.9
+22 −21
Original line number Diff line number Diff line
@@ -52,11 +52,11 @@ class GAT(nn.Module):
  def __init__(
      self,
      in_node_dim: int = 39,
      hidden_node_dim: int = 64,
      heads: int = 4,
      hidden_node_dim: int = 32,
      heads: int = 1,
      dropout: float = 0.0,
      num_conv: int = 3,
      predictor_hidden_feats: int = 32,
      num_conv: int = 2,
      predictor_hidden_feats: int = 64,
      n_tasks: int = 1,
      mode: str = 'classification',
      n_classes: int = 2,
@@ -67,19 +67,19 @@ class GAT(nn.Module):
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 64
    hidden_node_dim: int, default 32
      The length of the hidden node feature vectors.
    heads: int, default 4
    heads: int, default 1
      The number of multi-head-attentions.
    dropout: float, default 0.0
      The dropout probability for each convolutional layer.
    num_conv: int, default 3
    num_conv: int, default 2
      The number of convolutional layers.
    predictor_hidden_feats: int, default 32
      The size for hidden representations in the output MLP predictor, default to 32.
    predictor_hidden_feats: int, default 64
      The size for hidden representations in the output MLP predictor, default to 64.
    n_tasks: int, default 1
      The number of the output size, default to 1.
    mode: str, default 'regression'
    mode: str, default 'classification'
      The model type, 'classification' or 'regression'.
    n_classes: int, default 2
      The number of classes to predict (only used in classification mode).
@@ -131,7 +131,7 @@ class GAT(nn.Module):

    # pooling
    graph_feat = self.pooling(node_feat, data.batch)
    graph_feat = F.relu(self.fc(graph_feat))
    graph_feat = F.leaky_relu(self.fc(graph_feat))
    out = self.out(graph_feat)

    if self.mode == 'regression':
@@ -140,7 +140,7 @@ class GAT(nn.Module):
      logits = out.view(-1, self.n_tasks, self.n_classes)
      # for n_tasks == 1 case
      logits = torch.squeeze(logits)
      proba = F.softmax(logits)
      proba = F.softmax(logits, dim=-1)
      return proba, logits


@@ -177,11 +177,11 @@ class GATModel(TorchModel):

  def __init__(self,
               in_node_dim: int = 39,
               hidden_node_dim: int = 64,
               heads: int = 4,
               hidden_node_dim: int = 32,
               heads: int = 1,
               dropout: float = 0.0,
               num_conv: int = 3,
               predictor_hidden_feats: int = 32,
               num_conv: int = 2,
               predictor_hidden_feats: int = 64,
               n_tasks: int = 1,
               mode: str = 'regression',
               n_classes: int = 2,
@@ -194,16 +194,16 @@ class GATModel(TorchModel):
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 64
    hidden_node_dim: int, default 32
      The length of the hidden node feature vectors.
    heads: int, default 4
    heads: int, default 1
      The number of multi-head-attentions.
    dropout: float, default 0.0
      The dropout probability for each convolutional layer.
    num_conv: int, default 3
    num_conv: int, default 2
      The number of convolutional layers.
    predictor_hidden_feats: int, default 32
      The size for hidden representations in the output MLP predictor, default to 32.
    predictor_hidden_feats: int, default 64
      The size for hidden representations in the output MLP predictor, default to 64.
    n_tasks: int, default 1
      The number of the output size, default to 1.
    mode: str, default 'regression'
@@ -249,6 +249,7 @@ class GATModel(TorchModel):
    inputs, labels, weights = batch
    pyg_graphs = [graph.to_pyg_graph() for graph in inputs[0]]
    inputs = Batch.from_data_list(pyg_graphs)
    inputs = inputs.to(self.device)
    _, labels, weights = super(GATModel, self)._prepare_batch(([], labels,
                                                               weights))
    return inputs, labels, weights