Commit 762552e6 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #654 from patrickhop/upgrades

Upgraded MPNN [contrib]
parents 0efa3f60 24b8f9fc
Loading
Loading
Loading
Loading
+21 −10
Original line number Original line Diff line number Diff line
# Message Passing Neural Networks
# Message Passing Neural Networks


MPNNs aim to generalize molecular machine learning models that operate on graph-valued inputs. Graph-Convolutions [https://arxiv.org/abs/1509.09292] and Weaves [https://arxiv.org/abs/1603.00856] (among others) can be recast into this framework [https://arxiv.org/abs/1704.01212]
MPNNs aim to generalize molecular machine learning models that operate on graph-valued inputs. Graph-Convolutions [https://arxiv.org/abs/1509.09292] and Weaves [https://arxiv.org/abs/1603.00856] (among others) can be recast into this framework [https://arxiv.org/abs/1704.01212]


The premise is that the featurization of arbitrary chemical multigraphs can be broken down into a message function, vertex-update function, and a readout function that is invariant to graph isomorphisms. All functions must be subdifferentiable to preserve gradient-flow and ideally are learnable too.
The premise is that the featurization of arbitrary chemical multigraphs can be broken down into a message function, vertex-update function, and a readout function that is invariant to graph isomorphisms. All functions must be subdifferentiable to preserve gradient-flow and ideally are learnable as well.


Models of this style introduce an additional parameter **T**, which is the number of iterations for the message-passing stage. Values greater than 4 don't seem to improve performance.
Models of this style introduce an additional parameter **T**, which is the number of iterations for the message-passing stage. Values greater than 4 don't seem to improve performance.


Requires PyTorch.
## MPNN-S Variant
 MPNNs do provide a nice mathematical framework that can capture modern molecular machine learning algorithms we work with today. One criticism of this algorithm class is that training is slow, due to the sheer number of training iterations required for convergence - at batch size 20 on QM9, the MPNN authors trained for 540 epochs.
 
 
| Dataset | Examples | MP-DNN Val R2 (Index Split) |
This can be improved significantly by using batch normalization, or more interestingly, the new SELU activation [https://arxiv.org/pdf/1706.02515.pdf]. In order to use SELUs straight through the system, we dropped the GRU unit [https://arxiv.org/pdf/1412.3555.pdf] the authors used in favor of a SELU activated fully-connected neural network for each time step **T**. This modified approach now achieves peak performance in as little as 60 epochs on most molecular machine learning datasets.
| ------ | ------ | ------ |
| Delaney | 1102 | .801 |


## Running Code
MPNN-S sets new records on the Delaney & PPB datasets:

| Dataset | Num Examples | MP-DNN Val R2 [Scaffold Split] | GraphConv Val R2 [Scaffold Split] |
| ------ | ------ | ------ | ------ |
| Delaney | 1102 | **.820** | .606 |
| PPB | 1600 | **.427** | .381 |
| Clearance | 838 | **.32** | .28 |


## Run Code
```sh
```sh
$ python mpnn_baseline.py
$ python mpnn.py
```
```


License
License

contrib/mpnn/donkey.py

0 → 100644
+101 −0
Original line number Original line Diff line number Diff line
# 2017 DeepCrystal Technologies - Patrick Hop
#
# Data loading and splitting file
#
# MIT License - have fun!!
# ===========================================================

import os
import random
from collections import OrderedDict

import deepchem as dc
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from sklearn import preprocessing
from sklearn.decomposition import TruncatedSVD

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem

# Fix every RNG the script touches (Python, NumPy, PyTorch) so that the
# scaffold split and any downstream training are reproducible run-to-run.
random.seed(2)
np.random.seed(2)
torch.manual_seed(2)

def generate_scaffold(smiles, include_chirality=False):
  """Return the Bemis-Murcko scaffold of a molecule as a SMILES string.

  Parameters
  ----------
  smiles: str
    SMILES string of the input molecule.
  include_chirality: bool, optional
    Whether chirality is preserved in the generated scaffold.
  """
  molecule = Chem.MolFromSmiles(smiles)
  generator = ScaffoldGenerator(include_chirality=include_chirality)
  return generator.get_scaffold(molecule)

def split(dataset,
          frac_train=.80,
          frac_valid=.10,
          frac_test=.10,
          log_every_n=1000):
  """Partition compound indices into train/valid/test sets by scaffold.

  All compounds sharing a Bemis-Murcko scaffold are assigned to the same
  partition, so validation and test scaffolds are unseen during training.
  Fractions must sum to 1. Returns (train_inds, valid_inds, test_inds),
  three lists of integer indices into `dataset`.
  """
  np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)

  # Group dataset indices by their compound's scaffold.
  scaffolds = {}
  log("About to generate scaffolds", True)
  data_len = len(dataset)
  for ind, smiles in enumerate(dataset):
    if ind % log_every_n == 0:
      log("Generating scaffold %d/%d" % (ind, data_len), True)
    scaffolds.setdefault(generate_scaffold(smiles), []).append(ind)

  # Largest scaffold groups first; ties broken by smallest member index.
  scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
  scaffold_sets = [
      members for _, members in sorted(
          scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
  ]

  train_cutoff = frac_train * len(dataset)
  valid_cutoff = (frac_train + frac_valid) * len(dataset)
  train_inds, valid_inds, test_inds = [], [], []
  log("About to sort in scaffold sets", True)
  # Greedily fill train, then valid, spilling the remainder into test.
  for scaffold_set in scaffold_sets:
    if len(train_inds) + len(scaffold_set) <= train_cutoff:
      train_inds += scaffold_set
    elif len(train_inds) + len(valid_inds) + len(scaffold_set) <= valid_cutoff:
      valid_inds += scaffold_set
    else:
      test_inds += scaffold_set
  return train_inds, valid_inds, test_inds

def load_dataset(filename, whiten=False):
  """Load a CSV dataset and scaffold-split it into train/validation sets.

  Each data row's last comma-separated field is taken as the SMILES string
  and the second-to-last as the float label; the first (header) row is
  skipped.

  Parameters
  ----------
  filename: str
    Path to the CSV file.
  whiten: bool, optional
    Accepted for interface compatibility; currently unused.

  Returns
  -------
  (train_features, train_labels, val_features, val_labels)
    NumPy arrays of SMILES strings and float32 labels. Note that
    `np.take` flattens, so the label arrays come back 1-D.
  """
  features = []
  labels = []
  # Use a context manager so the file handle is always closed (the
  # original left it open). NOTE: naive comma split — assumes no quoted
  # fields containing commas.
  with open(filename, 'r') as f:
    next(f, None)  # skip header row; None guards against an empty file
    for line in f:
      splits = line[:-1].split(',')
      features.append(splits[-1])
      labels.append(float(splits[-2]))
  features = np.array(features)
  labels = np.array(labels, dtype='float32').reshape(-1, 1)

  # Scaffold split; the test indices are computed but deliberately unused
  # here (fixes the original's `test_ins` typo by naming it explicitly).
  train_ind, val_ind, _test_ind = split(features)

  train_features = np.take(features, train_ind)
  train_labels = np.take(labels, train_ind)
  val_features = np.take(features, val_ind)
  val_labels = np.take(labels, val_ind)

  return train_features, train_labels, val_features, val_labels
+182 −0
Original line number Original line Diff line number Diff line
# 2017 DeepCrystal Technologies - Patrick Hop
# 2017 DeepCrystal Technologies - Patrick Hop
#
#
# Message Passing Neural Network for Chemical Multigraphs
# Message Passing Neural Network SELU [MPNN-S] for Chemical Multigraphs
#
#
# MIT License - have fun!!
# MIT License - have fun!!
# ===========================================================
# ===========================================================


import math

import deepchem as dc
import deepchem as dc
from rdkit import Chem, DataStructs
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import AllChem
@@ -17,81 +19,71 @@ import torch.nn.functional as F


from sklearn.metrics import r2_score
from sklearn.metrics import r2_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn import preprocessing
import numpy as np
import numpy as np


import random
import random
from collections import OrderedDict
from collections import OrderedDict
from scipy.stats import pearsonr

import donkey


random.seed(2)
random.seed(2)
torch.manual_seed(2)
torch.manual_seed(2)
np.random.seed(2)
np.random.seed(2)


T = 4
DATASET = 'az_ppb.csv'
BATCH_SIZE = 64
print(DATASET)
MAXITER = 2000

T = 3
BATCH_SIZE = 48
MAXITER = 40000
LIMIT = 0
LR = 5e-4


#A = {}
R = nn.Linear(150, 128)
# valid_bonds = {'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}
U = {0: nn.Linear(156, 75), 1: nn.Linear(156, 75), 2: nn.Linear(156, 75)}
#for valid_bond in valid_bonds:
V = {0: nn.Linear(75, 75), 1: nn.Linear(75, 75), 2: nn.Linear(75, 75)}
#  A[valid_bond] = nn.Linear(75, 75)
E = nn.Linear(6, 6)


R = nn.Linear(75, 128)
def adjust_learning_rate(optimizer, epoch):
#GRU = nn.GRU(150, 75, 1)
  """Sets the learning rate to the initial LR decayed by .8 every 5 epochs"""
U = nn.Linear(150, 75)
  lr = LR * (0.9 ** (epoch // 10))
  print('new lr [%.5f]' % lr)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr


def load_dataset():
def load_dataset():
  f = open('delaney-processed.csv', 'r')
  train_features, train_labels, val_features, val_labels = donkey.load_dataset(DATASET)
  features = []

  labels = []
  scaler = preprocessing.StandardScaler().fit(train_labels)
  tracer = 0
  train_labels = scaler.transform(train_labels)
  for line in f:
  val_labels = scaler.transform(val_labels)
    if tracer == 0:
      tracer += 1
      continue
    splits =  line[:-1].split(',')
    features.append(splits[-1])
    labels.append(float(splits[-2]))

  train_features = np.array(features[:900])
  train_labels = np.array(labels[:900])
  val_features = np.array(features[900:1100])
  val_labels = np.array(labels[900:1100])


  train_labels = Variable(torch.FloatTensor(train_labels), requires_grad=False)
  train_labels = Variable(torch.FloatTensor(train_labels), requires_grad=False)
  val_labels = Variable(torch.FloatTensor(val_labels), requires_grad=False)
  val_labels = Variable(torch.FloatTensor(val_labels), requires_grad=False)
  
  
  return train_features, train_labels, val_features, val_labels
  return train_features, train_labels, val_features, val_labels


def readout(h):
def readout(h, h2):
  reads = map(lambda x: F.relu(R(h[x])), h.keys())
  catted_reads = map(lambda x: torch.cat([h[x[0]], h2[x[1]]], 1), zip(h2.keys(), h.keys()))
  activated_reads = map(lambda x: F.selu( R(x) ), catted_reads)
  readout = Variable(torch.zeros(1, 128))
  readout = Variable(torch.zeros(1, 128))
  for read in reads:
  for read in activated_reads:
    readout = readout + read
    readout = readout + read
  return readout
  return F.tanh( readout )


def message_pass(g, h, k):
def message_pass(g, h, k):
  #flow_delta = Variable(torch.zeros(1, 1))
  #h_t = Variable(torch.zeros(1, 1, 75))
  for v in g.keys():
  for v in g.keys():
    neighbors = g[v]
    neighbors = g[v]
    for neighbor in neighbors:
    for neighbor in neighbors:
      e_vw = neighbor[0]
      e_vw = neighbor[0] # feature variable
      w = neighbor[1]
      w = neighbor[1]
      #bond_type = e_vw.GetBondType()
      #A_vw = A[str(e_vw.GetBondType())]

      m_v = h[w]
      catted = torch.cat([h[v], m_v], 1)
      #gru_act, h_t = GRU(catted.view(1, 1, 150), h_t)
      
      
      # measure convergence
      m_w = V[k](h[w])
      #pdist = nn.PairwiseDistance(2)
      m_e_vw = E(e_vw)
      #flow_delta = flow_delta + torch.sum(pdist(gru_act.view(1, 75), h[v]))
      reshaped = torch.cat( (h[v], m_w, m_e_vw), 1)
      
      h[v] = F.selu(U[k](reshaped))
      #h[v] = gru_act.view(1, 75)
      h[v] = U(catted)

  #print '    flow delta [%i] [%f]' % (k, flow_delta.data.numpy()[0])


def construct_multigraph(smile):
def construct_multigraph(smile):
  g = OrderedDict({})
  g = OrderedDict({})
@@ -104,6 +96,8 @@ def construct_multigraph(smile):
    for j in xrange(0, molecule.GetNumAtoms()):
    for j in xrange(0, molecule.GetNumAtoms()):
      e_ij = molecule.GetBondBetweenAtoms(i, j)
      e_ij = molecule.GetBondBetweenAtoms(i, j)
      if e_ij != None:
      if e_ij != None:
        e_ij =  map(lambda x: 1 if x == True else 0, dc.feat.graph_features.bond_features(e_ij)) # ADDED edge feat
        e_ij = Variable(torch.FloatTensor(e_ij).view(1, 6))
        atom_j = molecule.GetAtomWithIdx(j)
        atom_j = molecule.GetAtomWithIdx(j)
        if i not in g:
        if i not in g:
          g[i] = []
          g[i] = []
@@ -113,99 +107,76 @@ def construct_multigraph(smile):


train_smiles, train_labels, val_smiles, val_labels = load_dataset()
train_smiles, train_labels, val_smiles, val_labels = load_dataset()


# training loop
linear = nn.Linear(128, 1)
linear = nn.Linear(128, 1)
params = [#{'params': A['SINGLE'].parameters()},
params = [{'params': R.parameters()},
         #{'params': A['DOUBLE'].parameters()},
         {'params': U[0].parameters()},
         #{'params': A['TRIPLE'].parameters()},
         {'params': U[1].parameters()},
         #{'params': A['AROMATIC'].parameters()},
         {'params': U[2].parameters()},
         {'params': R.parameters()},
         {'params': E.parameters()},
         #{'params': GRU.parameters()},
         {'params': V[0].parameters()},
         {'params': U.parameters()},
         {'params': V[1].parameters()},
         {'params': V[2].parameters()},
         {'params': linear.parameters()}]
         {'params': linear.parameters()}]


optimizer = optim.SGD(params, lr=1e-5, momentum=0.9)
num_epoch = 0
optimizer = optim.Adam(params, lr=LR, weight_decay=1e-4)
for i in xrange(0, MAXITER):
for i in xrange(0, MAXITER):
  optimizer.zero_grad()
  optimizer.zero_grad()
  train_loss = Variable(torch.zeros(1, 1))
  train_loss = Variable(torch.zeros(1, 1))
  y_hats_train = []
  y_hats_train = []
  for j in xrange(0, BATCH_SIZE):
  for j in xrange(0, BATCH_SIZE):
    sample_index = random.randint(0, 799) # TODO: sampling without replacement
    sample_index = random.randint(0, len(train_smiles) - 2)
    smile = train_smiles[sample_index]
    smile = train_smiles[sample_index]
    g, h = construct_multigraph(smile) # TODO: cache this
    g, h = construct_multigraph(smile) # TODO: cache this


    g2, h2 = construct_multigraph(smile)
    
    for k in xrange(0, T):
    for k in xrange(0, T):
      message_pass(g, h, k)
      message_pass(g, h, k)


    x = readout(h)
    x = readout(h, h2)
    #x = F.selu( fc(x) )
    y_hat = linear(x)
    y_hat = linear(x)
    y = train_labels[sample_index]
    y = train_labels[sample_index]


    y_hats_train.append(y_hat)
    y_hats_train.append(y_hat)


    error = (y_hat - y)*(y_hat - y)
    error = (y_hat - y)*(y_hat - y) / Variable(torch.FloatTensor([BATCH_SIZE])).view(1, 1)
    train_loss = train_loss + error
    train_loss = train_loss + error


  train_loss.backward()
  train_loss.backward()
  optimizer.step()
  optimizer.step()


  if i % 12 == 0:
  if i % int(len(train_smiles) / BATCH_SIZE) == 0:
    val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
    val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
    y_hats_val = []
    y_hats_val = []
    for j in xrange(0, len(val_smiles)):
    for j in xrange(0, len(val_smiles)):
      g, h = construct_multigraph(val_smiles[j])
      g, h = construct_multigraph(val_smiles[j])
      g2, h2 = construct_multigraph(val_smiles[j])


      for k in xrange(0, T):
      for k in xrange(0, T):
        message_pass(g, h, k)
        message_pass(g, h, k)


      x = readout(h)
      x = readout(h, h2)
      #x = F.selu( fc(x) )
      y_hat = linear(x)
      y_hat = linear(x)
      y = val_labels[j]
      y = val_labels[j]


      y_hats_val.append(y_hat)
      y_hats_val.append(y_hat)


      error = (y_hat - y)*(y_hat - y)
      error = (y_hat - y)*(y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
      val_loss = val_loss + error
      val_loss = val_loss + error


    y_hats_val = map(lambda x: x.data.numpy()[0], y_hats_val)
    y_hats_val = np.array(map(lambda x: x.data.numpy(), y_hats_val))
    y_val = map(lambda x: x.data.numpy()[0], val_labels)
    y_val = np.array(map(lambda x: x.data.numpy(), val_labels))
    r2_val = r2_score(y_val, y_hats_val)
    y_hats_val = y_hats_val.reshape(-1, 1)
    y_val = y_val.reshape(-1, 1)
    
    r2_val_old = r2_score(y_val, y_hats_val)
    r2_val_new = pearsonr(y_val, y_hats_val)[0]**2
  
  
    train_loss_ = train_loss.data.numpy()[0]
    train_loss_ = train_loss.data.numpy()[0]
    val_loss_ = val_loss.data.numpy()[0]
    val_loss_ = val_loss.data.numpy()[0]
    print 'epoch [%i/%i] train_loss [%f] val_loss [%f] r2_val [%s]' \
    print 'epoch [%i/%i] train_loss [%f] val_loss [%f] r2_val_old [%.4f], r2_val_new [%.4f]' \
                  % ((i + 1) / 12, maxiter_train / 12, train_loss_, val_loss_, r2_val)
                  % (num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new)

    num_epoch += 1
'''
train_labels = train_labels.data.numpy()
val_labels = val_labels.data.numpy()
  
train_mols = map(lambda x: Chem.MolFromSmiles(x), train_smiles)
train_fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in train_mols]
val_mols = map(lambda x: Chem.MolFromSmiles(x), val_smiles)
val_fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in val_mols]

np_fps_train = []
for fp in train_fps:
  arr = np.zeros((1,))
  DataStructs.ConvertToNumpyArray(fp, arr)
  np_fps_train.append(arr)

np_fps_val = []
for fp in val_fps:
  arr = np.zeros((1,))
  DataStructs.ConvertToNumpyArray(fp, arr)
  np_fps_val.append(arr)

rf = RandomForestRegressor(n_estimators=100, random_state=2)
#rf.fit(np_fps_train, train_labels)
#labels = rf.predict(val_fps)

ave = np.ones( (300,) )*(np.sum(val_labels) / 300.0)

print ave.shape
print val_labels.shape
r2 =  r2_score(ave, val_labels)
print 'rf r2 is:'
print r2
'''