Commit 9fd41988 authored by miaecle's avatar miaecle
Browse files

temp save

parent 91489e46
Loading
Loading
Loading
Loading
+32 −2
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ __license__ = "MIT"

import tensorflow as tf
from deepchem.nn.layers import GraphGather
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology, WeaveGraphTopology
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology, WeaveGraphTopology, WeaveGraphTopology_v2


class SequentialGraph(object):
@@ -195,6 +195,36 @@ class SequentialWeaveGraph(SequentialGraph):
        self.output = layer(self.output)
      self.layers.append(layer)

class SequentialWeaveGraph_v2(SequentialGraph):
  """Sequential layer container for the v2 Weave models.

  Owns its own tf.Graph together with a WeaveGraphTopology_v2 and tracks two
  running tensors while layers are stacked: `output` (atom features) and
  `output_P` (pair features).
  """

  def __init__(self, batch_size, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
    """
    Parameters
    ----------
    batch_size: int
      batch size, forwarded to WeaveGraphTopology_v2
    max_atoms: int, optional
      maximum number of atoms in a molecule
    n_atom_feat: int, optional
      number of basic features of each atom
    n_pair_feat: int, optional
      number of basic features of each pair
    """
    self.layers = []
    self.batch_size = batch_size
    self.max_atoms = max_atoms
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    self.graph = tf.Graph()
    # Placeholders must be created inside this model's own graph.
    with self.graph.as_default():
      topology = WeaveGraphTopology_v2(batch_size, max_atoms, n_atom_feat,
                                       n_pair_feat)
      self.graph_topology = topology
      self.output = topology.get_atom_features_placeholder()
      self.output_P = topology.get_pair_features_placeholder()

  def add(self, layer):
    """Adds a new layer to model.

    The layer is dispatched on its class name:
    'WeaveLayer_v2' receives atom features, pair features and every topology
    placeholder, and updates both running tensors; 'WeaveGather_v2' receives
    atom features plus the atom split placeholder; any other layer is applied
    to the atom-feature tensor alone.
    """
    layer_kind = type(layer).__name__
    topology = self.graph_topology
    with self.graph.as_default():
      if layer_kind == 'WeaveLayer_v2':
        inputs = [self.output, self.output_P]
        inputs = inputs + topology.get_topology_placeholders()
        self.output, self.output_P = layer(inputs)
      elif layer_kind == 'WeaveGather_v2':
        self.output = layer([self.output, topology.atom_split_placeholder])
      else:
        self.output = layer(self.output)
      self.layers.append(layer)


class SequentialSupportGraph(object):
+111 −0
Original line number Diff line number Diff line
@@ -492,3 +492,114 @@ class WeaveGraphTopology(GraphTopology):
        self.membership_placeholder: membership
    }
    return dict_DTNN

class WeaveGraphTopology_v2(GraphTopology):
  """Manages placeholders associated with batch of graphs and their topology

  Atom and pair features of all molecules in a batch are concatenated along
  the first axis; the split/membership/index placeholders carry the
  information needed to map the flat rows back to atoms and molecules.
  """

  def __init__(self, batch_size, max_atoms, n_atom_feat, n_pair_feat,
               name='Weave_topology'):
    """
    Parameters
    ----------
    batch_size: int
      number of molecules in a batch (length of the atom split placeholder)
    max_atoms: int
      maximum number of atoms in a molecule
    n_atom_feat: int
      number of basic features of each atom
    n_pair_feat: int
      number of basic features of each pair
    name: str, optional
      prefix used for the names of all placeholders
    """

    self.name = name
    self.batch_size = batch_size
    # NOTE: the attribute stores the atom capacity of a whole batch
    # (max_atoms * batch_size), not the per-molecule limit passed in.
    self.max_atoms = max_atoms * batch_size
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat

    self.atom_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(None, self.n_atom_feat),
        name=self.name + '_atom_features')
    self.pair_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(None, self.n_pair_feat),
        name=self.name + '_pair_features')
    # One entry per atom slot: number of pairs belonging to that atom,
    # zero for padding slots.
    self.pair_split_placeholder = tf.placeholder(
        dtype='int32', shape=(self.max_atoms,),
        name=self.name + '_pair_split')
    # True for slots holding a real atom, False for padding slots.
    self.pair_membership_placeholder = tf.placeholder(
        dtype='bool', shape=(self.max_atoms,),
        name=self.name + '_pair_membership')
    # Number of atoms in each molecule of the batch.
    self.atom_split_placeholder = tf.placeholder(
        dtype='int32', shape=(self.batch_size,),
        name=self.name + '_atom_split')
    # Row k holds the two atom indices (into the flat atom axis) of pair k.
    self.atom_to_pair_placeholder = tf.placeholder(
        dtype='int32', shape=(None, 2),
        name=self.name + '_atom_to_pair')

    # Define the list of tensors to be used as topology
    self.topology = [self.pair_split_placeholder,
                     self.pair_membership_placeholder,
                     self.atom_split_placeholder,
                     self.atom_to_pair_placeholder]
    self.inputs = [self.atom_features_placeholder]
    self.inputs += self.topology

  def get_pair_features_placeholder(self):
    """Returns the placeholder that receives the flat pair features."""
    return self.pair_features_placeholder

  def batch_to_feed_dict(self, batch):
    """Converts the current batch of WeaveMol into tensorflow feed_dict.

    Assigns the atom features and pair features to the
    placeholders tensors

    params
    ------
    batch : np.ndarray
      Array of WeaveMol; each element must provide get_num_atoms(),
      get_atom_features() and get_pair_features().

    returns
    -------
    feed_dict : dict
      Can be merged with other feed_dicts for input into tensorflow

    raises
    ------
    ValueError
      If the batch holds more atoms in total than the placeholders were
      sized for (max_atoms * batch_size).
    """
    atom_feat = []
    pair_feat = []
    atom_split = []
    atom_to_pair = []
    pair_split = []
    capacity = self.max_atoms
    # Offset of the current molecule's first atom on the flat atom axis.
    start = 0
    for mol in batch:
      n_atoms = mol.get_num_atoms()
      # number of atoms in each molecule
      atom_split.append(n_atoms)
      # index of pair features: all (i, j) atom index pairs of this molecule
      # in row-major order, shifted to batch-wide atom indices
      C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
      atom_to_pair.append(
          np.transpose(np.array([C1.flatten() + start,
                                 C0.flatten() + start])))
      start = start + n_atoms
      # number of pairs for each atom
      pair_split.extend([n_atoms] * n_atoms)
      # atom features
      atom_feat.append(mol.get_atom_features())
      # pair features, flattened to one row per atom pair
      # (assumes get_pair_features() holds n_atoms * n_atoms * n_pair_feat
      # values ordered consistently with atom_to_pair -- TODO confirm)
      pair_feat.append(
          np.reshape(mol.get_pair_features(),
                     (n_atoms * n_atoms, self.n_pair_feat)))

    # pair_split has one entry per atom, so its length is the atom total.
    n_total_atoms = len(pair_split)
    if n_total_atoms > capacity:
      # Fail early with a clear message instead of letting np.pad reject a
      # negative pad width.
      raise ValueError(
          'Batch contains %d atoms but the topology was built for at most '
          '%d (max_atoms * batch_size).' % (n_total_atoms, capacity))
    n_padding = capacity - n_total_atoms

    atom_feat = np.concatenate(atom_feat, axis=0)
    pair_feat = np.concatenate(pair_feat, axis=0)
    atom_to_pair = np.concatenate(atom_to_pair, axis=0)
    atom_split = np.array(atom_split)
    # Pad to the fixed placeholder length; padding slots own zero pairs and
    # are flagged False in the membership mask.
    pair_split = np.pad(pair_split, (0, n_padding), 'constant')
    pair_membership = np.array([True] * n_total_atoms + [False] * n_padding)
    # Generate dicts
    feed_dict = {
        self.atom_features_placeholder: atom_feat,
        self.pair_features_placeholder: pair_feat,
        self.pair_split_placeholder: pair_split,
        self.pair_membership_placeholder: pair_membership,
        self.atom_split_placeholder: atom_split,
        self.atom_to_pair_placeholder: atom_to_pair
    }
    return feed_dict
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ hps['dag'] = {
hps['weave'] = {
    'batch_size': 64,
    'nb_epoch': 40,
    'learning_rate': 0.001,
    'learning_rate': 0.0001,
    'n_graph_feat': 128,
    'n_pair_feat': 14,
    'seed': 123
+4 −0
Original line number Diff line number Diff line
@@ -21,8 +21,10 @@ from deepchem.nn.layers import DAGLayer
from deepchem.nn.layers import DAGGather

from deepchem.nn.weave_layers import WeaveLayer
from deepchem.nn.weave_layers import WeaveLayer_v2
from deepchem.nn.weave_layers import WeaveConcat
from deepchem.nn.weave_layers import WeaveGather
from deepchem.nn.weave_layers import WeaveGather_v2

from deepchem.nn.model_ops import weight_decay
from deepchem.nn.model_ops import optimizer
@@ -36,8 +38,10 @@ from deepchem.models.tf_new_models.graph_topology import GraphTopology
from deepchem.models.tf_new_models.graph_topology import DTNNGraphTopology
from deepchem.models.tf_new_models.graph_topology import DAGGraphTopology
from deepchem.models.tf_new_models.graph_topology import WeaveGraphTopology
from deepchem.models.tf_new_models.graph_topology import WeaveGraphTopology_v2
from deepchem.models.tf_new_models.graph_models import SequentialGraph
from deepchem.models.tf_new_models.graph_models import SequentialDTNNGraph
from deepchem.models.tf_new_models.graph_models import SequentialDAGGraph
from deepchem.models.tf_new_models.graph_models import SequentialWeaveGraph
from deepchem.models.tf_new_models.graph_models import SequentialWeaveGraph_v2
from deepchem.models.tf_new_models.graph_models import SequentialSupportGraph
+97 −1
Original line number Diff line number Diff line
@@ -175,6 +175,66 @@ class WeaveLayer(Layer):
    return A, P


class WeaveLayer_v2(WeaveLayer):
  """WeaveLayer variant that works on atoms and pairs laid out flat along
  the first axis, using split sizes and index pairs instead of a padded
  per-molecule layout.
  """

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    x = [atom_features, pair_features, pair_split, pair_membership,
         atom_split, atom_to_pair]

    Parameters
    ----------
    x: list
      list of Tensors of form described above. All six entries must be
      present; atom_split (x[4]) is accepted for a uniform topology list
      but is not used by this layer.
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    A: Tensor
      Tensor of atom_features
    P: Tensor
      Tensor of pair_features
    """
    # Add trainable weights
    self.build()

    atom_features = x[0]
    pair_features = x[1]

    pair_split = x[2]
    pair_membership = x[3]
    atom_to_pair = x[5]

    # Atom -> atom transform
    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = self.activation(AA)
    # Pair -> atom transform: sum the transformed pair rows of each atom slot
    PA = tf.matmul(pair_features, self.W_PA) + self.b_PA
    PA = self.activation(PA)
    PAs = tf.split(PA, pair_split, axis=0)
    PA = tf.stack([tf.reduce_sum(molecule, 0) for molecule in PAs])
    # Drop the slots flagged False in pair_membership
    PA = tf.boolean_mask(PA, pair_membership)

    A = tf.matmul(tf.concat([AA, PA], 1), self.W_A) + self.b_A
    A = self.activation(A)

    if self.update_pair:
      # Atom -> pair transform, applied to both orderings of each index pair
      # and summed so the result does not depend on the order
      AP_ij = tf.matmul(
          tf.reshape(tf.gather(atom_features, atom_to_pair),
                     [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ij = self.activation(AP_ij)
      AP_ji = tf.matmul(
          tf.reshape(tf.gather(atom_features, tf.reverse(atom_to_pair, [1])),
                     [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ji = self.activation(AP_ji)

      # Pair -> pair transform
      PP = tf.matmul(pair_features, self.W_PP) + self.b_PP
      PP = self.activation(PP)
      P = tf.matmul(tf.concat([AP_ij + AP_ji, PP], 1), self.W_P) + self.b_P
      P = self.activation(P)
    else:
      # Pair features pass through unchanged
      P = pair_features

    return A, P
    
    
class WeaveConcat(Layer):
  """" Concat a batch of molecules into a batch of atoms
  """
@@ -342,3 +402,39 @@ class WeaveGather(Layer):
    outputs = outputs / tf.reduce_sum(outputs, axis=2, keep_dims=True)
    outputs = tf.reshape(outputs, [-1, self.n_input * 11])
    return outputs
    
class WeaveGather_v2(WeaveGather):
  """WeaveGather variant that pools atoms laid out flat along the first
  axis, using per-molecule atom counts to find the molecule boundaries.
  """

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    x = [atom_features, atom_split]

    Parameters
    ----------
    x: list
      Tensors as listed above
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    outputs: Tensor
      Tensor of molecular features
    """
    # Add trainable weights
    self.build()
    atom_features, atom_split = x[0], x[1]

    if self.gaussian_expand:
      atom_features = self.gaussian_histogram(atom_features)

    # Cut the flat atom axis into one chunk per molecule and sum each chunk.
    atoms_by_molecule = tf.split(atom_features, atom_split, axis=0)
    molecule_features = tf.stack(
        [tf.reduce_sum(atoms, 0) for atoms in atoms_by_molecule])

    if not self.gaussian_expand:
      return molecule_features

    # With gaussian expansion, apply the layer's dense transform and
    # activation to the pooled features.
    molecule_features = tf.matmul(molecule_features, self.W) + self.b
    return self.activation(molecule_features)
 No newline at end of file
Loading