Commit ee33411e authored by Ubuntu's avatar Ubuntu
Browse files

Update

parent 84f04f58
Loading
Loading
Loading
Loading
+1 −4
Original line number Diff line number Diff line
@@ -27,10 +27,7 @@ def test_attentivefp_regression():

  # initialize models
  n_tasks = len(tasks)
  model = AttentiveFPModel(
      mode='regression',
      n_tasks=n_tasks,
      batch_size=10)
  model = AttentiveFPModel(mode='regression', n_tasks=n_tasks, batch_size=10)

  # overfit test
  model.fit(dataset, nb_epoch=100)
+63 −61
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.torch_models.torch_model import TorchModel


class AttentiveFP(nn.Module):
  """Model for Graph Property Prediction.

@@ -123,7 +124,8 @@ class AttentiveFP(nn.Module):

    from dgllife.model import AttentiveFPPredictor as DGLAttentiveFPPredictor

    self.model = DGLAttentiveFPPredictor(node_feat_size=number_atom_features,
    self.model = DGLAttentiveFPPredictor(
        node_feat_size=number_atom_features,
        edge_feat_size=number_bond_features,
        num_layers=num_layers,
        num_timesteps=num_timesteps,
@@ -316,6 +318,6 @@ class AttentiveFPModel(TorchModel):
        graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]
    ]
    inputs = dgl.batch(dgl_graphs).to(self.device)
        _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(([], labels,
                                                                           weights))
    _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(
        ([], labels, weights))
    return inputs, labels, weights