Commit 8488e279 authored by miaecle's avatar miaecle
Browse files

Merge remote-tracking branch 'remotes/mine/weave' into weave

parents 94489f71 882feccf
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -161,6 +161,7 @@ class SequentialDAGGraph(SequentialGraph):
        self.output = layer(self.output)
      self.layers.append(layer)


class SequentialWeaveGraph(SequentialGraph):
  """SequentialGraph for Weave models
  """
@@ -171,8 +172,7 @@ class SequentialWeaveGraph(SequentialGraph):
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    with self.graph.as_default():
      self.graph_topology = WeaveGraphTopology(self.max_atoms, 
                                               self.n_atom_feat,
      self.graph_topology = WeaveGraphTopology(self.max_atoms, self.n_atom_feat,
                                               self.n_pair_feat)
      self.output = self.graph_topology.get_atom_features_placeholder()
      self.output_P = self.graph_topology.get_pair_features_placeholder()
@@ -182,8 +182,9 @@ class SequentialWeaveGraph(SequentialGraph):
    """Adds a new layer to model."""
    with self.graph.as_default():
      if type(layer).__name__ in ['WeaveLayer']:
        self.output, self.output_P = layer([self.output, self.output_P] +
                            self.graph_topology.get_topology_placeholders())
        self.output, self.output_P = layer([
            self.output, self.output_P
        ] + self.graph_topology.get_topology_placeholders())
      elif type(layer).__name__ in ['WeaveConcat']:
        self.output = layer(
            [self.output, self.graph_topology.atom_mask_placeholder])
@@ -194,6 +195,7 @@ class SequentialWeaveGraph(SequentialGraph):
        self.output = layer(self.output)
      self.layers.append(layer)


class SequentialSupportGraph(object):
  """An analog of Keras Sequential model for test/support models."""

+15 −15
Original line number Diff line number Diff line
@@ -393,10 +393,12 @@ class DAGGraphTopology(GraphTopology):
        output[ide] = self.batch_size * self.max_atoms
    return output


class WeaveGraphTopology(GraphTopology):
  """Manages placeholders associated with batch of graphs and their topology"""

  def __init__(self, max_atoms, n_atom_feat, n_pair_feat, name='Weave_topology'):
  def __init__(self, max_atoms, n_atom_feat, n_pair_feat,
               name='Weave_topology'):
    """
    Parameters
    ----------
@@ -431,9 +433,7 @@ class WeaveGraphTopology(GraphTopology):
        shape=(None, self.max_atoms, self.max_atoms),
        name=self.name + '_pair_mask')
    self.membership_placeholder = tf.placeholder(
        dtype='int32',
        shape=(None,),
        name=self.name + '_membership')
        dtype='int32', shape=(None,), name=self.name + '_membership')
    # Define the list of tensors to be used as topology
    self.topology = [self.atom_mask_placeholder, self.pair_mask_placeholder]
    self.inputs = [self.atom_features_placeholder]
@@ -467,14 +467,14 @@ class WeaveGraphTopology(GraphTopology):
    max_atoms = self.max_atoms
    for im, mol in enumerate(batch):
      n_atoms = mol.get_num_atoms()
      atom_feat.append(np.pad(mol.get_atom_features(), ((0, max_atoms - n_atoms), 
                                                        (0,0)), 
                              'constant'))
      atom_mask.append(np.array([1]*n_atoms + [0]*(max_atoms-n_atoms), dtype=float))
      pair_feat.append(np.pad(mol.get_pair_features(), ((0, max_atoms - n_atoms),
                                                        (0, max_atoms - n_atoms),
                                                        (0,0)),
      atom_feat.append(
          np.pad(mol.get_atom_features(), ((0, max_atoms - n_atoms), (0, 0)),
                 'constant'))
      atom_mask.append(
          np.array([1] * n_atoms + [0] * (max_atoms - n_atoms), dtype=float))
      pair_feat.append(
          np.pad(mol.get_pair_features(), ((0, max_atoms - n_atoms), (
              0, max_atoms - n_atoms), (0, 0)), 'constant'))
      pair_mask.append(np.array([[1]*n_atoms + [0]*(max_atoms-n_atoms)]*n_atoms + \
                       [[0]*max_atoms]*(max_atoms-n_atoms), dtype=float))
      membership.extend([im] * n_atoms)