Unverified commit e64a6e0f, authored by Bharath Ramsundar and committed by GitHub

Merge pull request #2138 from nd-02110114/gat-add-mode

Fix GAT and CGCNN support for GPU
parents b78de6f0 ee0b13eb
+2 −2
@@ -91,7 +91,7 @@ class Featurizer(object):
    >>> dc.feat.CircularFingerprint(size=1024, radius=4)
    CircularFingerprint[radius=4, size=1024, chiral=False, bonds=True, features=False, sparse=False, smiles=False]
    >>> dc.feat.CGCNNFeaturizer()
-    CGCNNFeaturizer[radius=8.0, max_neighbors=8, step=0.2]
+    CGCNNFeaturizer[radius=8.0, max_neighbors=12, step=0.2]
    """
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
@@ -277,7 +277,7 @@ class MolecularFeaturizer(Featurizer):
        features.append(self._featurize(mol))
      except:
        logger.warning(
            "Failed to featurize datapoint %d. Appending empty array")
            "Failed to featurize datapoint %d. Appending empty array", i)
        features.append(np.array([]))

    features = np.asarray(features)
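
The logging fix above passes the datapoint index `i` through lazy `%`-style formatting, so the warning actually names the failing datapoint instead of printing a bare `%d` placeholder. A minimal standalone sketch of the difference (the index value here is hypothetical):

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

i = 7  # hypothetical index of the datapoint that failed to featurize
# Old call: no argument supplied, so the literal "%d" is printed in the message.
logger.warning("Failed to featurize datapoint %d. Appending empty array")
# Fixed call: logging interpolates the index lazily, only if the record is emitted.
logger.warning("Failed to featurize datapoint %d. Appending empty array", i)
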
+2 −2
@@ -50,14 +50,14 @@ class CGCNNFeaturizer(MaterialStructureFeaturizer):

  def __init__(self,
               radius: float = 8.0,
-               max_neighbors: float = 8,
+               max_neighbors: float = 12,
               step: float = 0.2):
    """
    Parameters
    ----------
    radius: float (default 8.0)
      Radius of sphere for finding neighbors of atoms in unit cell.
-    max_neighbors: int (default 8)
+    max_neighbors: int (default 12)
      Maximum number of neighbors to consider when constructing graph.
    step: float (default 0.2)
      Step size for Gaussian filter. This value is used when building edge features.
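
Since the default `max_neighbors` changes from 8 to 12 here, a minimal sketch of how that surfaces to users (mirroring the repr doctest updated above; nothing beyond the constructor is assumed):

import deepchem as dc

# Default construction now reports max_neighbors=12.
featurizer = dc.feat.CGCNNFeaturizer()
print(featurizer)  # CGCNNFeaturizer[radius=8.0, max_neighbors=12, step=0.2]

# The previous default can still be requested explicitly.
featurizer_old = dc.feat.CGCNNFeaturizer(radius=8.0, max_neighbors=8, step=0.2)
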
+11 −1
@@ -192,7 +192,17 @@ class SparseSoftmaxCrossEntropy(Loss):

  def _create_pytorch_loss(self):
    import torch
-    return torch.nn.CrossEntropyLoss(reduction='none')
+    ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
+
+    def loss(output, labels):
+      # Convert (batch_size, tasks, classes) to (batch_size, classes, tasks)
+      # CrossEntropyLoss only supports (batch_size, classes, tasks)
+      # This is for API consistency
+      if len(output.shape) == 3:
+        output = output.permute(0, 2, 1)
+      return ce_loss(output, labels.long())
+
+    return loss


def _make_tf_shapes_consistent(output, labels):
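
The wrapper above is needed because `torch.nn.CrossEntropyLoss` expects class logits on dimension 1 when the input has more than two dimensions, i.e. `(batch_size, classes, d1, ...)`, while the models here emit `(batch_size, tasks, classes)`. A rough standalone sketch of the shape handling (tensor sizes are hypothetical, not taken from DeepChem):

import torch

batch_size, n_tasks, n_classes = 10, 3, 2
output = torch.randn(batch_size, n_tasks, n_classes)       # logits: (batch, tasks, classes)
labels = torch.randint(n_classes, (batch_size, n_tasks))   # class ids: (batch, tasks)

ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
# Swap the last two dimensions so classes sit on dim 1, as CrossEntropyLoss requires.
per_element = ce_loss(output.permute(0, 2, 1), labels.long())
print(per_element.shape)  # torch.Size([10, 3]): one unreduced loss per sample per task
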
+28 −6
import unittest

from deepchem.feat import MolGraphConvFeaturizer
-from deepchem.models import GATModel, losses
+from deepchem.models import GATModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
@@ -14,19 +14,41 @@ except:

@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
-def test_gat_classification():
+def test_gat_regression():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
+  model = GATModel(mode='regression', n_tasks=n_tasks, batch_size=10)
+
+  # overfit test
+  # GAT's convergence is a little slow
+  model.fit(dataset, nb_epoch=300)
+  scores = model.evaluate(dataset, [metric], transformers)
+  assert scores['mean_absolute_error'] < 0.5
+
+
+@unittest.skipIf(not has_pytorch_and_pyg,
+                 'PyTorch and PyTorch Geometric are not installed')
+def test_gat_classification():
+  # load datasets
+  featurizer = MolGraphConvFeaturizer()
+  tasks, dataset, transformers, metric = get_dataset(
+      'classification', featurizer=featurizer)
+
+  # initialize models
+  n_tasks = len(tasks)
  model = GATModel(
-      n_tasks=n_tasks, loss=losses.L2Loss(), batch_size=4, learning_rate=0.001)
+      mode='classification',
+      n_tasks=n_tasks,
+      batch_size=10,
+      learning_rate=0.001)

  # overfit test
-  model.fit(dataset, nb_epoch=100)
+  # GAT's convergence is a little slow
+  model.fit(dataset, nb_epoch=150)
  scores = model.evaluate(dataset, [metric], transformers)
-  # TODO: check this asseration is correct or not
-  assert scores['mean_absolute_error'] < 1.0
+  assert scores['mean-roc_auc_score'] >= 0.9
+36 −25
@@ -32,9 +32,9 @@ class CGCNNLayer(nn.Module):
  >>> print(type(cgcnn_dgl_graph))
  <class 'dgl.heterograph.DGLHeteroGraph'>
  >>> layer = CGCNNLayer(hidden_node_dim=92, edge_dim=41)
-  >>> update_graph = layer(cgcnn_dgl_graph)
-  >>> print(type(update_graph))
-  <class 'dgl.heterograph.DGLHeteroGraph'>
+  >>> node_feats = cgcnn_dgl_graph.ndata.pop('x')
+  >>> edge_feats = cgcnn_dgl_graph.edata.pop('edge_attr')
+  >>> new_node_feats, new_edge_feats = layer(cgcnn_dgl_graph, node_feats, edge_feats)

  Notes
  -----
@@ -57,40 +57,48 @@ class CGCNNLayer(nn.Module):
    """
    super(CGCNNLayer, self).__init__()
    z_dim = 2 * hidden_node_dim + edge_dim
-    self.linear_with_sigmoid = nn.Linear(z_dim, hidden_node_dim)
-    self.linear_with_softplus = nn.Linear(z_dim, hidden_node_dim)
-    self.batch_norm = nn.BatchNorm1d(hidden_node_dim) if batch_norm else None
+    liner_out_dim = 2 * hidden_node_dim
+    self.linear = nn.Linear(z_dim, liner_out_dim)
+    self.batch_norm = nn.BatchNorm1d(liner_out_dim) if batch_norm else None

  def message_func(self, edges):
    z = torch.cat(
        [edges.src['x'], edges.dst['x'], edges.data['edge_attr']], dim=1)
-    gated_z = torch.sigmoid(self.linear_with_sigmoid(z))
-    message_z = F.softplus(self.linear_with_softplus(z))
-    return {'gated_z': gated_z, 'message_z': message_z}
+    z = self.linear(z)
+    if self.batch_norm is not None:
+      z = self.batch_norm(z)
+    gated_z, message_z = z.chunk(2, dim=1)
+    gated_z = torch.sigmoid(gated_z)
+    message_z = F.softplus(message_z)
+    return {'message': gated_z * message_z}

  def reduce_func(self, nodes):
-    new_h = nodes.data['x'] + torch.sum(
-        nodes.mailbox['gated_z'] * nodes.mailbox['message_z'], dim=1)
-    return {'x': new_h}
+    nbr_sumed = torch.sum(nodes.mailbox['message'], dim=1)
+    new_x = F.softplus(nodes.data['x'] + nbr_sumed)
+    return {'new_x': new_x}

-  def forward(self, dgl_graph):
-    """Update node representaions.
+  def forward(self, dgl_graph, node_feats, edge_feats):
+    """Update node representations.

    Parameters
    ----------
    dgl_graph: DGLGraph
-      DGLGraph for a batch of graphs. The graph expects that the node features
-      are stored in `ndata['x']`, and the edge features are stored in `edata['edge_attr']`.
+      DGLGraph for a batch of graphs.
+    node_feats: torch.Tensor
+      The node features. The shape is `(N, hidden_node_dim)`.
+    edge_feats: torch.Tensor
+      The edge features. The shape is `(N, hidden_node_dim)`.

    Returns
    -------
-    dgl_graph: DGLGraph
-      DGLGraph for a batch of updated graphs.
+    node_feats: torch.Tensor
+      The updated node features. The shape is `(N, hidden_node_dim)`.
    """
+    dgl_graph.ndata['x'] = node_feats
+    dgl_graph.edata['edge_attr'] = edge_feats
    dgl_graph.update_all(self.message_func, self.reduce_func)
-    if self.batch_norm is not None:
-      dgl_graph.ndata['x'] = self.batch_norm(dgl_graph.ndata['x'])
-    return dgl_graph
+    node_feats = dgl_graph.ndata.pop('new_x')
+    return node_feats


class CGCNN(nn.Module):
@@ -215,15 +223,18 @@ class CGCNN(nn.Module):
    """
    graph = dgl_graph
    # embedding node features
-    graph.ndata['x'] = self.embedding(graph.ndata['x'])
+    node_feats = graph.ndata.pop('x')
+    edge_feats = graph.edata.pop('edge_attr')
+    node_feats = self.embedding(node_feats)

    # convolutional layer
    for conv in self.conv_layers:
-      graph = conv(graph)
+      node_feats = conv(graph, node_feats, edge_feats)

    # pooling
-    graph_feat = self.pooling(graph, 'x')
-    graph_feat = self.fc(graph_feat)
+    graph.ndata['updated_x'] = node_feats
+    graph_feat = F.softplus(self.pooling(graph, 'updated_x'))
+    graph_feat = F.softplus(self.fc(graph_feat))
    out = self.out(graph_feat)

    if self.mode == 'regression':
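
With this refactor, `CGCNNLayer` no longer reads node and edge features from `ndata['x']`/`edata['edge_attr']` on its own; the caller pops them off the graph and passes them in, and `CGCNN.forward` threads them through the convolution stack, which is presumably what keeps device placement manageable on GPU. A rough sketch of the new calling convention, assuming `CGCNNLayer` is in scope and `cgcnn_dgl_graph` was built as in the layer's docstring above (note that `forward` as written returns only the updated node features):

# Pop the raw features off the batched DGL graph.
node_feats = cgcnn_dgl_graph.ndata.pop('x')
edge_feats = cgcnn_dgl_graph.edata.pop('edge_attr')

layer = CGCNNLayer(hidden_node_dim=92, edge_dim=41)
# The layer writes the features back onto the graph, runs update_all with its
# message/reduce functions, then pops the aggregated 'new_x' tensor and returns it.
new_node_feats = layer(cgcnn_dgl_graph, node_feats, edge_feats)
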