Commit fd7a6c87 authored by leswing's avatar leswing
Browse files

Merge branch 'master' into 660-pdb_rf

parents a5842c4a 6dd2f1b3
Loading
Loading
Loading
Loading
+21 −10
Original line number Original line Diff line number Diff line
# Message Passing Neural Networks
# Message Passing Neural Networks


MPNNs aim to generalize molecular machine learning models that operate on graph-valued inputs. Graph-Convolutions [https://arxiv.org/abs/1509.09292] and Weaves [https://arxiv.org/abs/1603.00856] (among others) can be recast into this framework [https://arxiv.org/abs/1704.01212]
MPNNs aim to generalize molecular machine learning models that operate on graph-valued inputs. Graph-Convolutions [https://arxiv.org/abs/1509.09292] and Weaves \
[https://arxiv.org/abs/1603.00856] (among others) can be recast into this framework [https://arxiv.org/abs/1704.01212]


The premise is that the featurization of arbitrary chemical multigraphs can be broken down into a message function, vertex-update function, and a readout function that is invariant to graph isomorphisms. All functions must be subdifferentiable to preserve gradient-flow and ideally are learnable too.
The premise is that the featurization of arbitrary chemical multigraphs can be broken down into a message function, vertex-update function, and a readout functi\
on that is invariant to graph isomorphisms. All functions must be subdifferentiable to preserve gradient-flow and ideally are learnable as well


Models of this style introduce an additional parameter **T**, which is the number of iterations for the message-passing stage. Values greater than 4 don't seem to improve performance.
Models of this style introduce an additional parameter **T**, which is the number of iterations for the message-passing stage. Values greater than 4 don't seem \
to improve performance.


Requires PyTorch.
## MPNN-S Variant
 MPNNs do provide a nice mathematical framework that can capture modern molecular machine learning algorithms we work with today. One criticism of this algorithm class is that training is slow, due to the sheer number of training iterations required for convergence - at batch size 20 on QM9, the MPNN authors trained for 540 epochs.
 
 
| Dataset | Examples | MP-DNN Val R2 (Index Split) |
This can be improved significantly by using batch normalization, or more interestingly, the new SELU activation [https://arxiv.org/pdf/1706.02515.pdf]. In order to use SELUs straight through the system, we dropped the GRU unit [https://arxiv.org/pdf/1412.3555.pdf] the authors used in favor of a SELU activated fully-connected neural network for each time step **T**. This modified approach now achieves peak performance in as little as 60 epochs on most molecular machine learning datasets.
| ------ | ------ | ------ |
| Delaney | 1102 | .801 |


## Running Code
MPNN-S sets new records on the Delaney & PPB datasets:

| Dataset | Num Examples | MP-DNN Val R2 [Scaffold Split] | GraphConv Val R2 [Scaffold Split] |
| ------ | ------ | ------ | ------ |
| Delaney | 1102 | **.820** | .606 |
| PPB | 1600 | **.427** | .381 |
| Clearance | 838 | **.32** | .28 |


## Run Code
```sh
```sh
$ python mpnn_baseline.py
$ python mpnn.py
```
```


License
License

contrib/mpnn/donkey.py

0 → 100644
+101 −0
Original line number Original line Diff line number Diff line
# 2017 DeepCrystal Technologies - Patrick Hop
#
# Data loading and splitting file
#
# MIT License - have fun!!
# ===========================================================

import os
import random
from collections import OrderedDict

import deepchem as dc
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from sklearn import preprocessing
from sklearn.decomposition import TruncatedSVD

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem

# Fix all RNG seeds (Python, NumPy, PyTorch) so dataset splits and any
# downstream weight initialization are reproducible across runs.
random.seed(2)
np.random.seed(2)
torch.manual_seed(2)

def generate_scaffold(smiles, include_chirality=False):
  """Return the Bemis-Murcko scaffold for a SMILES string.

  Parameters
  ----------
  smiles: str
    SMILES representation of the input molecule.
  include_chirality: bool, optional
    Whether chirality information is retained in the scaffold.
  """
  generator = ScaffoldGenerator(include_chirality=include_chirality)
  return generator.get_scaffold(Chem.MolFromSmiles(smiles))

def split(dataset,
          frac_train=.80,
          frac_valid=.10,
          frac_test=.10,
          log_every_n=1000):
  """Partition compounds into train/validation/test index lists by scaffold.

  Compounds sharing a Bemis-Murcko scaffold always land in the same subset;
  the largest scaffold groups are assigned to the training set first, so the
  validation/test sets tend to contain rarer scaffolds.

  Returns (train_inds, valid_inds, test_inds) as lists of dataset indices.
  """
  # The three fractions must form a proper partition of the dataset.
  np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
  log("About to generate scaffolds", True)
  n_samples = len(dataset)

  # Group dataset indices by their scaffold string.
  groups = {}
  for idx, smiles in enumerate(dataset):
    if idx % log_every_n == 0:
      log("Generating scaffold %d/%d" % (idx, n_samples), True)
    groups.setdefault(generate_scaffold(smiles), []).append(idx)

  # Order scaffold groups largest-first; the (size, smallest-index) key is
  # unique per group since every index belongs to exactly one scaffold.
  scaffold_sets = sorted(
      (sorted(members) for members in groups.values()),
      key=lambda members: (len(members), members[0]),
      reverse=True)

  train_cutoff = frac_train * n_samples
  valid_cutoff = (frac_train + frac_valid) * n_samples
  train_inds, valid_inds, test_inds = [], [], []
  log("About to sort in scaffold sets", True)
  # Greedily fill train, then valid, spilling the remainder into test.
  for members in scaffold_sets:
    if len(train_inds) + len(members) > train_cutoff:
      if len(train_inds) + len(valid_inds) + len(members) > valid_cutoff:
        test_inds.extend(members)
      else:
        valid_inds.extend(members)
    else:
      train_inds.extend(members)
  return train_inds, valid_inds, test_inds

def load_dataset(filename, whiten=False):
  """Load a CSV dataset and scaffold-split it into train/validation sets.

  Each data row is expected to carry the label in its second-to-last column
  and the SMILES string in its last column; the first line is treated as a
  header and skipped.

  Parameters
  ----------
  filename: str
    Path to the CSV file.
  whiten: bool, optional
    Unused; kept for backward compatibility.
    NOTE(review): whitening was never implemented here — confirm whether
    callers expect it before removing the parameter.

  Returns
  -------
  (train_features, train_labels, val_features, val_labels). The test
  partition produced by the scaffold split is intentionally discarded.
  """
  features = []
  labels = []
  # Context manager guarantees the file handle is closed (the original
  # opened the file and never closed it).
  with open(filename, 'r') as f:
    f.readline()  # skip the header row
    for line in f:
      splits = line[:-1].split(',')
      features.append(splits[-1])
      labels.append(float(splits[-2]))
  features = np.array(features)
  labels = np.array(labels, dtype='float32').reshape(-1, 1)

  # Scaffold split; test indices are deliberately unused (was a typo'd
  # `test_ins` in the original).
  train_ind, val_ind, _test_ind = split(features)

  train_features = np.take(features, train_ind)
  train_labels = np.take(labels, train_ind)
  val_features = np.take(features, val_ind)
  val_labels = np.take(labels, val_ind)

  return train_features, train_labels, val_features, val_labels
+182 −0
Original line number Original line Diff line number Diff line
# 2017 DeepCrystal Technologies - Patrick Hop
# 2017 DeepCrystal Technologies - Patrick Hop
#
#
# Message Passing Neural Network for Chemical Multigraphs
# Message Passing Neural Network SELU [MPNN-S] for Chemical Multigraphs
#
#
# MIT License - have fun!!
# MIT License - have fun!!
# ===========================================================
# ===========================================================


import math

import deepchem as dc
import deepchem as dc
from rdkit import Chem, DataStructs
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import AllChem
@@ -17,81 +19,71 @@ import torch.nn.functional as F


from sklearn.metrics import r2_score
from sklearn.metrics import r2_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn import preprocessing
import numpy as np
import numpy as np


import random
import random
from collections import OrderedDict
from collections import OrderedDict
from scipy.stats import pearsonr

import donkey


random.seed(2)
random.seed(2)
torch.manual_seed(2)
torch.manual_seed(2)
np.random.seed(2)
np.random.seed(2)


T = 4
DATASET = 'az_ppb.csv'
BATCH_SIZE = 64
print(DATASET)
MAXITER = 2000

T = 3
BATCH_SIZE = 48
MAXITER = 40000
LIMIT = 0
LR = 5e-4


#A = {}
R = nn.Linear(150, 128)
# valid_bonds = {'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}
U = {0: nn.Linear(156, 75), 1: nn.Linear(156, 75), 2: nn.Linear(156, 75)}
#for valid_bond in valid_bonds:
V = {0: nn.Linear(75, 75), 1: nn.Linear(75, 75), 2: nn.Linear(75, 75)}
#  A[valid_bond] = nn.Linear(75, 75)
E = nn.Linear(6, 6)


R = nn.Linear(75, 128)
def adjust_learning_rate(optimizer, epoch):
#GRU = nn.GRU(150, 75, 1)
  """Sets the learning rate to the initial LR decayed by .8 every 5 epochs"""
U = nn.Linear(150, 75)
  lr = LR * (0.9 ** (epoch // 10))
  print('new lr [%.5f]' % lr)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr


def load_dataset():
def load_dataset():
  f = open('delaney-processed.csv', 'r')
  train_features, train_labels, val_features, val_labels = donkey.load_dataset(DATASET)
  features = []

  labels = []
  scaler = preprocessing.StandardScaler().fit(train_labels)
  tracer = 0
  train_labels = scaler.transform(train_labels)
  for line in f:
  val_labels = scaler.transform(val_labels)
    if tracer == 0:
      tracer += 1
      continue
    splits =  line[:-1].split(',')
    features.append(splits[-1])
    labels.append(float(splits[-2]))

  train_features = np.array(features[:900])
  train_labels = np.array(labels[:900])
  val_features = np.array(features[900:1100])
  val_labels = np.array(labels[900:1100])


  train_labels = Variable(torch.FloatTensor(train_labels), requires_grad=False)
  train_labels = Variable(torch.FloatTensor(train_labels), requires_grad=False)
  val_labels = Variable(torch.FloatTensor(val_labels), requires_grad=False)
  val_labels = Variable(torch.FloatTensor(val_labels), requires_grad=False)
  
  
  return train_features, train_labels, val_features, val_labels
  return train_features, train_labels, val_features, val_labels


def readout(h):
def readout(h, h2):
  reads = map(lambda x: F.relu(R(h[x])), h.keys())
  catted_reads = map(lambda x: torch.cat([h[x[0]], h2[x[1]]], 1), zip(h2.keys(), h.keys()))
  activated_reads = map(lambda x: F.selu( R(x) ), catted_reads)
  readout = Variable(torch.zeros(1, 128))
  readout = Variable(torch.zeros(1, 128))
  for read in reads:
  for read in activated_reads:
    readout = readout + read
    readout = readout + read
  return readout
  return F.tanh( readout )


def message_pass(g, h, k):
def message_pass(g, h, k):
  #flow_delta = Variable(torch.zeros(1, 1))
  #h_t = Variable(torch.zeros(1, 1, 75))
  for v in g.keys():
  for v in g.keys():
    neighbors = g[v]
    neighbors = g[v]
    for neighbor in neighbors:
    for neighbor in neighbors:
      e_vw = neighbor[0]
      e_vw = neighbor[0] # feature variable
      w = neighbor[1]
      w = neighbor[1]
      #bond_type = e_vw.GetBondType()
      #A_vw = A[str(e_vw.GetBondType())]

      m_v = h[w]
      catted = torch.cat([h[v], m_v], 1)
      #gru_act, h_t = GRU(catted.view(1, 1, 150), h_t)
      
      
      # measure convergence
      m_w = V[k](h[w])
      #pdist = nn.PairwiseDistance(2)
      m_e_vw = E(e_vw)
      #flow_delta = flow_delta + torch.sum(pdist(gru_act.view(1, 75), h[v]))
      reshaped = torch.cat( (h[v], m_w, m_e_vw), 1)
      
      h[v] = F.selu(U[k](reshaped))
      #h[v] = gru_act.view(1, 75)
      h[v] = U(catted)

  #print '    flow delta [%i] [%f]' % (k, flow_delta.data.numpy()[0])


def construct_multigraph(smile):
def construct_multigraph(smile):
  g = OrderedDict({})
  g = OrderedDict({})
@@ -104,6 +96,8 @@ def construct_multigraph(smile):
    for j in xrange(0, molecule.GetNumAtoms()):
    for j in xrange(0, molecule.GetNumAtoms()):
      e_ij = molecule.GetBondBetweenAtoms(i, j)
      e_ij = molecule.GetBondBetweenAtoms(i, j)
      if e_ij != None:
      if e_ij != None:
        e_ij =  map(lambda x: 1 if x == True else 0, dc.feat.graph_features.bond_features(e_ij)) # ADDED edge feat
        e_ij = Variable(torch.FloatTensor(e_ij).view(1, 6))
        atom_j = molecule.GetAtomWithIdx(j)
        atom_j = molecule.GetAtomWithIdx(j)
        if i not in g:
        if i not in g:
          g[i] = []
          g[i] = []
@@ -113,99 +107,76 @@ def construct_multigraph(smile):


train_smiles, train_labels, val_smiles, val_labels = load_dataset()
train_smiles, train_labels, val_smiles, val_labels = load_dataset()


# training loop
linear = nn.Linear(128, 1)
linear = nn.Linear(128, 1)
params = [#{'params': A['SINGLE'].parameters()},
params = [{'params': R.parameters()},
         #{'params': A['DOUBLE'].parameters()},
         {'params': U[0].parameters()},
         #{'params': A['TRIPLE'].parameters()},
         {'params': U[1].parameters()},
         #{'params': A['AROMATIC'].parameters()},
         {'params': U[2].parameters()},
         {'params': R.parameters()},
         {'params': E.parameters()},
         #{'params': GRU.parameters()},
         {'params': V[0].parameters()},
         {'params': U.parameters()},
         {'params': V[1].parameters()},
         {'params': V[2].parameters()},
         {'params': linear.parameters()}]
         {'params': linear.parameters()}]


optimizer = optim.SGD(params, lr=1e-5, momentum=0.9)
num_epoch = 0
optimizer = optim.Adam(params, lr=LR, weight_decay=1e-4)
for i in xrange(0, MAXITER):
for i in xrange(0, MAXITER):
  optimizer.zero_grad()
  optimizer.zero_grad()
  train_loss = Variable(torch.zeros(1, 1))
  train_loss = Variable(torch.zeros(1, 1))
  y_hats_train = []
  y_hats_train = []
  for j in xrange(0, BATCH_SIZE):
  for j in xrange(0, BATCH_SIZE):
    sample_index = random.randint(0, 799) # TODO: sampling without replacement
    sample_index = random.randint(0, len(train_smiles) - 2)
    smile = train_smiles[sample_index]
    smile = train_smiles[sample_index]
    g, h = construct_multigraph(smile) # TODO: cache this
    g, h = construct_multigraph(smile) # TODO: cache this


    g2, h2 = construct_multigraph(smile)
    
    for k in xrange(0, T):
    for k in xrange(0, T):
      message_pass(g, h, k)
      message_pass(g, h, k)


    x = readout(h)
    x = readout(h, h2)
    #x = F.selu( fc(x) )
    y_hat = linear(x)
    y_hat = linear(x)
    y = train_labels[sample_index]
    y = train_labels[sample_index]


    y_hats_train.append(y_hat)
    y_hats_train.append(y_hat)


    error = (y_hat - y)*(y_hat - y)
    error = (y_hat - y)*(y_hat - y) / Variable(torch.FloatTensor([BATCH_SIZE])).view(1, 1)
    train_loss = train_loss + error
    train_loss = train_loss + error


  train_loss.backward()
  train_loss.backward()
  optimizer.step()
  optimizer.step()


  if i % 12 == 0:
  if i % int(len(train_smiles) / BATCH_SIZE) == 0:
    val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
    val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
    y_hats_val = []
    y_hats_val = []
    for j in xrange(0, len(val_smiles)):
    for j in xrange(0, len(val_smiles)):
      g, h = construct_multigraph(val_smiles[j])
      g, h = construct_multigraph(val_smiles[j])
      g2, h2 = construct_multigraph(val_smiles[j])


      for k in xrange(0, T):
      for k in xrange(0, T):
        message_pass(g, h, k)
        message_pass(g, h, k)


      x = readout(h)
      x = readout(h, h2)
      #x = F.selu( fc(x) )
      y_hat = linear(x)
      y_hat = linear(x)
      y = val_labels[j]
      y = val_labels[j]


      y_hats_val.append(y_hat)
      y_hats_val.append(y_hat)


      error = (y_hat - y)*(y_hat - y)
      error = (y_hat - y)*(y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
      val_loss = val_loss + error
      val_loss = val_loss + error


    y_hats_val = map(lambda x: x.data.numpy()[0], y_hats_val)
    y_hats_val = np.array(map(lambda x: x.data.numpy(), y_hats_val))
    y_val = map(lambda x: x.data.numpy()[0], val_labels)
    y_val = np.array(map(lambda x: x.data.numpy(), val_labels))
    r2_val = r2_score(y_val, y_hats_val)
    y_hats_val = y_hats_val.reshape(-1, 1)
    y_val = y_val.reshape(-1, 1)
    
    r2_val_old = r2_score(y_val, y_hats_val)
    r2_val_new = pearsonr(y_val, y_hats_val)[0]**2
  
  
    train_loss_ = train_loss.data.numpy()[0]
    train_loss_ = train_loss.data.numpy()[0]
    val_loss_ = val_loss.data.numpy()[0]
    val_loss_ = val_loss.data.numpy()[0]
    print 'epoch [%i/%i] train_loss [%f] val_loss [%f] r2_val [%s]' \
    print 'epoch [%i/%i] train_loss [%f] val_loss [%f] r2_val_old [%.4f], r2_val_new [%.4f]' \
                  % ((i + 1) / 12, maxiter_train / 12, train_loss_, val_loss_, r2_val)
                  % (num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new)

    num_epoch += 1
'''
train_labels = train_labels.data.numpy()
val_labels = val_labels.data.numpy()
  
train_mols = map(lambda x: Chem.MolFromSmiles(x), train_smiles)
train_fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in train_mols]
val_mols = map(lambda x: Chem.MolFromSmiles(x), val_smiles)
val_fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in val_mols]

np_fps_train = []
for fp in train_fps:
  arr = np.zeros((1,))
  DataStructs.ConvertToNumpyArray(fp, arr)
  np_fps_train.append(arr)

np_fps_val = []
for fp in val_fps:
  arr = np.zeros((1,))
  DataStructs.ConvertToNumpyArray(fp, arr)
  np_fps_val.append(arr)

rf = RandomForestRegressor(n_estimators=100, random_state=2)
#rf.fit(np_fps_train, train_labels)
#labels = rf.predict(val_fps)

ave = np.ones( (300,) )*(np.sum(val_labels) / 300.0)

print ave.shape
print val_labels.shape
r2 =  r2_score(ave, val_labels)
print 'rf r2 is:'
print r2
'''
+33 −7
Original line number Original line Diff line number Diff line
@@ -27,6 +27,9 @@ class Layer(object):
    self.in_layers = in_layers
    self.in_layers = in_layers
    self.op_type = "gpu"
    self.op_type = "gpu"
    self.variable_scope = ''
    self.variable_scope = ''
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []


  def _get_layer_number(self):
  def _get_layer_number(self):
    class_name = self.__class__.__name__
    class_name = self.__class__.__name__
@@ -395,16 +398,35 @@ class GRU(Layer):
      raise ValueError("Must have one parent")
      raise ValueError("Must have one parent")
    parent_tensor = inputs[0]
    parent_tensor = inputs[0]
    gru_cell = tf.contrib.rnn.GRUCell(self.n_hidden)
    gru_cell = tf.contrib.rnn.GRUCell(self.n_hidden)
    initial_gru_state = gru_cell.zero_state(self.batch_size, tf.float32)
    zero_state = gru_cell.zero_state(self.batch_size, tf.float32)
    out_tensor, rnn_states = tf.nn.dynamic_rnn(
    if set_tensors:
        gru_cell,
      initial_state = tf.placeholder(tf.float32, zero_state.get_shape())
        parent_tensor,
    else:
        initial_state=initial_gru_state,
      initial_state = zero_state
        scope=self.name)
    out_tensor, final_state = tf.nn.dynamic_rnn(
        gru_cell, parent_tensor, initial_state=initial_state, scope=self.name)
    if set_tensors:
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
      self.out_tensor = out_tensor
      self.rnn_initial_states.append(initial_state)
      self.rnn_final_states.append(final_state)
      self.rnn_zero_states.append(np.zeros(zero_state.get_shape(), np.float32))
    return out_tensor
    return out_tensor


  def none_tensors(self):
    saved_tensors = [
        self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
        self.rnn_zero_states
    ]
    self.out_tensor = None
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    return saved_tensors

  def set_tensors(self, tensor):
    self.out_tensor, self.rnn_initial_states, self.rnn_final_states, self.rnn_zero_states = tensor



class TimeSeriesDense(Layer):
class TimeSeriesDense(Layer):


@@ -745,6 +767,7 @@ class Conv2D(Layer):
               stride=1,
               stride=1,
               padding='SAME',
               padding='SAME',
               activation_fn=tf.nn.relu,
               activation_fn=tf.nn.relu,
               normalizer_fn=None,
               scope_name=None,
               scope_name=None,
               **kwargs):
               **kwargs):
    """Create a Conv2D layer.
    """Create a Conv2D layer.
@@ -765,12 +788,15 @@ class Conv2D(Layer):
      the padding method to use, either 'SAME' or 'VALID'
      the padding method to use, either 'SAME' or 'VALID'
    activation_fn: object
    activation_fn: object
      the Tensorflow activation function to apply to the output
      the Tensorflow activation function to apply to the output
    normalizer_fn: object
      the Tensorflow normalizer function to apply to the output
    """
    """
    self.num_outputs = num_outputs
    self.num_outputs = num_outputs
    self.kernel_size = kernel_size
    self.kernel_size = kernel_size
    self.stride = stride
    self.stride = stride
    self.padding = padding
    self.padding = padding
    self.activation_fn = activation_fn
    self.activation_fn = activation_fn
    self.normalizer_fn = normalizer_fn
    super(Conv2D, self).__init__(**kwargs)
    super(Conv2D, self).__init__(**kwargs)
    if scope_name is None:
    if scope_name is None:
      scope_name = self.name
      scope_name = self.name
@@ -786,7 +812,7 @@ class Conv2D(Layer):
        stride=self.stride,
        stride=self.stride,
        padding=self.padding,
        padding=self.padding,
        activation_fn=self.activation_fn,
        activation_fn=self.activation_fn,
        normalizer_fn=tf.contrib.layers.batch_norm,
        normalizer_fn=self.normalizer_fn,
        scope=self.scope_name)
        scope=self.scope_name)
    out_tensor = out_tensor
    out_tensor = out_tensor
    if set_tensors:
    if set_tensors:
+19 −0
Original line number Original line Diff line number Diff line
@@ -94,6 +94,10 @@ class TensorGraph(Model):
    self.save_file = "%s/%s" % (self.model_dir, "model")
    self.save_file = "%s/%s" % (self.model_dir, "model")
    self.model_class = None
    self.model_class = None


    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []

  def _add_layer(self, layer):
  def _add_layer(self, layer):
    if layer.name is None:
    if layer.name is None:
      layer.name = "%s_%s" % (layer.__class__.__name__, len(self.layers) + 1)
      layer.name = "%s_%s" % (layer.__class__.__name__, len(self.layers) + 1)
@@ -226,6 +230,9 @@ class TensorGraph(Model):
          feed_dict[self.features[0]] = X_b
          feed_dict[self.features[0]] = X_b
        if len(self.task_weights) == 1 and w_b is not None and not predict:
        if len(self.task_weights) == 1 and w_b is not None and not predict:
          feed_dict[self.task_weights[0]] = w_b
          feed_dict[self.task_weights[0]] = w_b
        for (inital_state, zero_state) in zip(self.rnn_initial_states,
                                              self.rnn_zero_states):
          feed_dict[initial_state] = zero_state
        yield feed_dict
        yield feed_dict


  def predict_on_generator(self, generator, transformers=[]):
  def predict_on_generator(self, generator, transformers=[]):
@@ -328,6 +335,9 @@ class TensorGraph(Model):
        with tf.name_scope(node):
        with tf.name_scope(node):
          node_layer = self.layers[node]
          node_layer = self.layers[node]
          node_layer.create_tensor(training=self._training_placeholder)
          node_layer.create_tensor(training=self._training_placeholder)
          self.rnn_initial_states += node_layer.rnn_initial_states
          self.rnn_final_states += node_layer.rnn_final_states
          self.rnn_zero_states += node_layer.rnn_zero_states
      self.built = True
      self.built = True


    for layer in self.layers.values():
    for layer in self.layers.values():
@@ -412,7 +422,13 @@ class TensorGraph(Model):
    # Remove out_tensor from the object to be pickled
    # Remove out_tensor from the object to be pickled
    must_restore = False
    must_restore = False
    tensor_objects = self.tensor_objects
    tensor_objects = self.tensor_objects
    rnn_initial_states = self.rnn_initial_states
    rnn_final_states = self.rnn_final_states
    rnn_zero_states = self.rnn_zero_states
    self.tensor_objects = {}
    self.tensor_objects = {}
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    out_tensors = []
    out_tensors = []
    if self.built:
    if self.built:
      must_restore = True
      must_restore = True
@@ -440,6 +456,9 @@ class TensorGraph(Model):
      self._training_placeholder = training_placeholder
      self._training_placeholder = training_placeholder
      self.built = True
      self.built = True
    self.tensor_objects = tensor_objects
    self.tensor_objects = tensor_objects
    self.rnn_initial_states = rnn_initial_states
    self.rnn_final_states = rnn_final_states
    self.rnn_zero_states = rnn_zero_states


  def evaluate_generator(self,
  def evaluate_generator(self,
                         feed_dict_generator,
                         feed_dict_generator,
Loading