Commit ff608daf authored by miaecle's avatar miaecle
Browse files

update DAG

parent b17d1282
Loading
Loading
Loading
Loading
+35 −32
Original line number Diff line number Diff line
@@ -410,6 +410,7 @@ class DAGLayer(Layer):
               activation='relu',
               dropout=None,
               max_atoms=50,
               batch_size=64,
               **kwargs):
    """
    Parameters
@@ -436,6 +437,7 @@ class DAGLayer(Layer):
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
    self.batch_size = batch_size
    self.n_inputs = n_atom_feat + (self.max_atoms - 1) * n_graph_feat
    # number of inputs each step
    self.n_graph_feat = n_graph_feat
@@ -464,7 +466,7 @@ class DAGLayer(Layer):
    
  def _create_tensor(self):
    """description and explanation refer to deepchem.nn.DAGLayer
    parent layers: atom_features, parents, calculation_orders, membership
    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms
    """
    # Add trainable weights
    self.build()
@@ -475,55 +477,55 @@ class DAGLayer(Layer):
    parents = self.in_layers[1].out_tensor
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = self.in_layers[2].out_tensor
    membership = self.in_layers[3].out_tensor
    calculation_masks = self.in_layers[3].out_tensor

    n_atoms = atom_features.get_shape()[0]
    n_atoms = self.in_layers[4].out_tensor
    # initialize graph features for each graph
    graph_features_initial = tf.zeros((self.max_atoms*self.batch_size, self.max_atoms+1, self.n_graph_feat))
    # initialize graph features for each graph
    # another row of zeros is generated for padded dummy atoms
    graph_features = tf.Variable(
        tf.constant(0., shape=(n_atoms, self.max_atoms + 1, self.n_graph_feat)),
        graph_features_initial,
        trainable=False)
    
    # add dummy
    atom_features = tf.concat(
        axis=0,
        values=[
            atom_features, tf.constant(0., shape=(1, self.n_atom_feat))
        ])
    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      mask = calculation_masks[:, count]
      current_round = tf.boolean_mask(calculation_orders[:, count], mask)
      batch_atom_features = tf.gather(atom_features,
                                      calculation_orders[:, count])
                                      current_round)

      # generating index for graph features used in the inputs
      index = tf.stack(
          [
              tf.reshape(
                  tf.stack([tf.range(n_atoms)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(parents[:, count, 1:], [-1])
                  tf.stack([tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
          ],
          axis=1)
      # extracting graph features for parents of the target atoms, then flatten
      # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, index),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])

      # concat into the input tensor: (batch_size*max_atoms) * n_inputs
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # DAGgraph_step maps from batch_inputs to a batch of graph_features
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = self.DAGgraph_step(batch_inputs, self.W_list, self.b_list)

      # index for targe atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      # index for dummies
      target_index2 = tf.stack(
          [tf.range(n_atoms), tf.constant(self.max_atoms, shape=(n_atoms,))],
          axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_index,
                                            batch_outputs)
      # recover dummies to zeros if being updated
      graph_features = tf.scatter_nd_update(graph_features, target_index2,
                                            tf.zeros(
                                                (n_atoms, self.n_graph_feat)))

    # last step generates graph features for all target atoms
    # masking the outputs
    outputs = tf.multiply(batch_outputs,
                          tf.expand_dims(tf.to_float(membership), axis=1))
    self.out_tensor = outputs

    self.out_tensor = batch_outputs

  def DAGgraph_step(self, batch_inputs, W_list, b_list):
    outputs = batch_inputs
@@ -540,7 +542,7 @@ class DAGGather(Layer):
  def __init__(self,
               n_graph_feat=30,
               n_outputs=30,
               layer_sizes=[1000],
               layer_sizes=[100],
               init='glorot_uniform',
               activation='relu',
               dropout=None,
@@ -603,8 +605,9 @@ class DAGGather(Layer):

    # Extract atom_features
    atom_features = self.in_layers[0].out_tensor
    graph_features = tf.reshape(atom_features, [-1, self.max_atoms, self.n_graph_feat])
    graph_features = tf.reduce_sum(graph_features, axis=1)
    membership = self.in_layers[1].out_tensor
    # Extract atom_features
    graph_features = tf.segment_sum(atom_features, membership)
    # sum all graph outputs
    outputs = self.DAGgraph_step(graph_features, self.W_list, self.b_list)
    self.out_tensor = outputs
+32 −52
Original line number Diff line number Diff line
@@ -141,11 +141,11 @@ class WeaveTensorGraph(TensorGraph):

  def predict(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_on_generator(generator)
    return self.predict_on_generator(generator, transformers)

  def predict_proba(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_proba_on_generator(generator)
    return self.predict_proba_on_generator(generator, transformers)

  def predict_on_generator(self, generator, transformers=[]):
    retval = self.predict_proba_on_generator(generator, transformers)
@@ -290,11 +290,11 @@ class DTNNTensorGraph(TensorGraph):

  def predict(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_on_generator(generator)
    return self.predict_on_generator(generator, transformers)

  def predict_proba(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_proba_on_generator(generator)
    return self.predict_proba_on_generator(generator, transformers)

  def predict_on_generator(self, generator, transformers=[]):
    retval = self.predict_proba_on_generator(generator, transformers)
@@ -342,20 +342,23 @@ class DAGTensorGraph(TensorGraph):
    self.build_graph()

  def build_graph(self):
    self.atom_features = Feature(shape=(self.batch_size*self.max_atoms, self.n_atom_feat))
    self.parents = Feature(shape=(self.batch_size*self.max_atoms, self.max_atoms, self.max_atoms), dtype=tf.int32)
    self.calculation_orders = Feature(shape=(self.batch_size*self.max_atoms, self.max_atoms), dtype=tf.int32)
    self.membership = Feature(shape=(self.batch_size*self.max_atoms), dtype=tf.int32)
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.parents = Feature(shape=(None, self.max_atoms, self.max_atoms), dtype=tf.int32)
    self.calculation_orders = Feature(shape=(None, self.max_atoms), dtype=tf.int32)
    self.calculation_masks = Feature(shape=(None, self.max_atoms), dtype=tf.bool)
    self.membership = Feature(shape=(None,), dtype=tf.int32)
    self.n_atoms = Feature(shape=(), dtype=tf.int32)
    dag_layer1 = DAGLayer(
        n_graph_feat=self.n_graph_feat,
        n_atom_feat=self.n_atom_feat,
        max_atoms=self.max_atoms,
        in_layers=[self.atom_features, self.parents, self.calculation_orders, self.membership])
        batch_size=self.batch_size,
        in_layers=[self.atom_features, self.parents, self.calculation_orders, self.calculation_masks, self.n_atoms])
    dag_gather = DAGGather(
        n_graph_feat=self.n_graph_feat,
        n_outputs=self.n_outputs,
        max_atoms=self.max_atoms,
        in_layers=[dag_layer1])
        in_layers=[dag_layer1, self.membership])

    costs = []
    self.labels_fd = []
@@ -405,64 +408,41 @@ class DAGTensorGraph(TensorGraph):
          feed_dict[self.weights] = w_b
        
        atoms_per_mol = [mol.get_num_atoms() for mol in X_b]
        n_atom_features = X_b[0].get_atom_features().shape[1]
        membership = np.concatenate(
            [
                np.array([1] * n_atoms + [0] * (self.max_atoms - n_atoms))
                for n_atoms in atoms_per_mol
            ],
            axis=0)
        n_atoms = sum(atoms_per_mol)
        start_index = [0] + list(np.cumsum(atoms_per_mol)[:-1])

        atoms_all = []
        # calculation orders for a batch of molecules
        parents_all = []
        calculation_orders = []
        calculation_masks = []
        membership = []
        for idm, mol in enumerate(X_b):
          atom_features_padded = np.concatenate(
              [
                  mol.get_atom_features(), np.zeros(
                      (self.max_atoms - atoms_per_mol[idm], n_atom_features))
              ],
              axis=0)
          atoms_all.append(atom_features_padded)
    
          # padding atom features vector of each molecule with 0
          atoms_all.append(mol.get_atom_features())
          parents = mol.parents
          assert len(parents) == atoms_per_mol[idm]
          parents_all.extend(parents[:])
          parents_all.extend([
              self.max_atoms * np.ones((self.max_atoms, self.max_atoms), dtype=int)
              for i in range(self.max_atoms - atoms_per_mol[idm])
          ])
          for parent in parents:
            calculation_orders.append(self.index_changing(parent[:, 0], idm))
    
          calculation_orders.extend([
              self.batch_size * self.max_atoms * np.ones(
                  (self.max_atoms,), dtype=int)
              for i in range(self.max_atoms - atoms_per_mol[idm])
          ])
          parents_all.extend(parents)
          calculation_index = np.array(parents)[:, :, 0]
          mask = np.array(calculation_index-self.max_atoms, dtype=bool)
          calculation_orders.append(calculation_index + start_index[idm])
          calculation_masks.append(mask)
          membership.extend([idm]*atoms_per_mol[idm])
    
        feed_dict[self.atom_features] = np.concatenate(atoms_all, axis=0)
        feed_dict[self.parents] = np.stack(parents_all, axis=0)
        feed_dict[self.calculation_orders] = np.stack(calculation_orders, axis=0)
        feed_dict[self.membership] = membership
        feed_dict[self.calculation_orders] = np.concatenate(calculation_orders, axis=0)
        feed_dict[self.calculation_masks] = np.concatenate(calculation_masks, axis=0)
        feed_dict[self.membership] = np.array(membership)
        feed_dict[self.n_atoms] = n_atoms
        yield feed_dict

  def index_changing(self, index, n_mol):
    output = np.zeros_like(index)
    for ide, element in enumerate(index):
      if element < self.max_atoms:
        output[ide] = element + n_mol * self.max_atoms
      else:
        output[ide] = self.batch_size * self.max_atoms
    return output

  def predict(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_on_generator(generator)
    return self.predict_on_generator(generator, transformers)

  def predict_proba(self, dataset, transformers=[], batch_size=None):
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_proba_on_generator(generator)
    return self.predict_proba_on_generator(generator, transformers)

  def predict_on_generator(self, generator, transformers=[]):
    retval = self.predict_proba_on_generator(generator, transformers)
+5 −5
Original line number Diff line number Diff line
@@ -126,21 +126,19 @@ class SequentialDAGGraph(SequentialGraph):
  """SequentialGraph for DAG models
  """

  def __init__(self, n_feat, batch_size=50, max_atoms=50):
  def __init__(self, n_atom_feat=75, max_atoms=50):
    """
    Parameters
    ----------
    n_feat: int
    n_atom_feat: int
      Number of features per atom.
    batch_size: int, optional(default=50)
      Number of molecules in a batch
    max_atoms: int, optional(default=50)
      Maximum number of atoms in a molecule, should be defined based on dataset
    """
    self.graph = tf.Graph()
    with self.graph.as_default():
      self.graph_topology = DAGGraphTopology(
          n_feat, batch_size, max_atoms=max_atoms)
          n_atom_feat=n_atom_feat, max_atoms=max_atoms)
      self.output = self.graph_topology.get_atom_features_placeholder()
    self.layers = []

@@ -150,6 +148,8 @@ class SequentialDAGGraph(SequentialGraph):
      if type(layer).__name__ in ['DAGLayer']:
        self.output = layer([self.output] +
                            self.graph_topology.get_topology_placeholders())
      elif type(layer).__name__ in ['DAGGather']:
        self.output = layer([self.output, self.graph_topology.membership_placeholder])
      else:
        self.output = layer(self.output)
      self.layers.append(layer)
+40 −59
Original line number Diff line number Diff line
@@ -266,39 +266,50 @@ class DAGGraphTopology(GraphTopology):
  """GraphTopology for DAG models
  """

  def __init__(self, n_feat, batch_size, name='topology', max_atoms=50):
  def __init__(self, n_atom_feat=75, max_atoms=50, name='topology'):

    self.n_feat = n_feat
    self.name = name
    self.n_atom_feat = n_atom_feat
    self.max_atoms = max_atoms
    self.batch_size = batch_size
    self.name = name
    self.atom_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(self.batch_size * self.max_atoms, self.n_feat),
        shape=(None, self.n_atom_feat),
        name=self.name + '_atom_features')

    self.parents_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms, self.max_atoms,
        shape=(None, self.max_atoms,
               self.max_atoms),
        # molecule * atom(graph) => step => features
        name=self.name + '_parents')

    self.calculation_orders_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms, self.max_atoms),
        shape=(None, self.max_atoms),
        # molecule * atom(graph) => step
        name=self.name + '_orders')

    self.calculation_masks_placeholder = tf.placeholder(
        dtype='bool',
        shape=(None, self.max_atoms),
        # molecule * atom(graph) => step
        name=self.name + '_masks')
    
    self.membership_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms),
        shape=(None,),
        name=self.name + '_membership')
    
    self.n_atoms_placeholder = tf.placeholder(
        dtype='int32',
        shape=(),
        name=self.name + '_n_atoms')
    
    # Define the list of tensors to be used as topology
    self.topology = [
        self.parents_placeholder, self.calculation_orders_placeholder,
        self.membership_placeholder
        self.calculation_masks_placeholder, self.membership_placeholder,
        self.n_atoms_placeholder
    ]

    self.inputs = [self.atom_features_placeholder]
@@ -328,73 +339,43 @@ class DAGGraphTopology(GraphTopology):
    """

    atoms_per_mol = [mol.get_num_atoms() for mol in batch]
    n_atom_features = batch[0].get_atom_features().shape[1]
    membership = np.concatenate(
        [
            np.array([1] * n_atoms + [0] * (self.max_atoms - n_atoms))
            for i, n_atoms in enumerate(atoms_per_mol)
        ],
        axis=0)
    n_atoms = sum(atoms_per_mol)
    start_index = [0] + list(np.cumsum(atoms_per_mol)[:-1])

    atoms_all = []
    # calculation orders for a batch of molecules
    parents_all = []
    calculation_orders = []
    calculation_masks = []
    membership = []
    for idm, mol in enumerate(batch):
      # padding atom features vector of each molecule with 0
      atom_features_padded = np.concatenate(
          [
              mol.get_atom_features(), np.zeros(
                  (self.max_atoms - atoms_per_mol[idm], n_atom_features))
          ],
          axis=0)
      atoms_all.append(atom_features_padded)

      # calculation orders for DAGs
      atoms_all.append(mol.get_atom_features())
      parents = mol.parents
      # number of DAGs should equal number of atoms
      assert len(parents) == atoms_per_mol[idm]
      parents_all.extend(parents[:])
      # padding with `max_atoms`
      parents_all.extend([
          self.max_atoms * np.ones((self.max_atoms, self.max_atoms), dtype=int)
          for i in range(self.max_atoms - atoms_per_mol[idm])
      ])
      for parent in parents:
        # index for an atom in `parents_all` and `atoms_all` is different, 
        # this function changes the index from the position in current molecule(DAGs, `parents_all`) 
        # to position in batch of molecules(`atoms_all`)
        # only used in tf.gather on `atom_features_placeholder`
        calculation_orders.append(self.index_changing(parent[:, 0], idm))

      # padding with `batch_size*max_atoms`
      calculation_orders.extend([
          self.batch_size * self.max_atoms * np.ones(
              (self.max_atoms,), dtype=int)
          for i in range(self.max_atoms - atoms_per_mol[idm])
      ])
      parents_all.extend(parents)
      calculation_index = np.array(parents)[:, :, 0]
      mask = np.array(calculation_index-self.max_atoms, dtype=bool)
      calculation_orders.append(calculation_index + start_index[idm])
      calculation_masks.append(mask)
      membership.extend([idm]*atoms_per_mol[idm])

    atoms_all = np.concatenate(atoms_all, axis=0)
    parents_all = np.stack(parents_all, axis=0)
    calculation_orders = np.stack(calculation_orders, axis=0)
    calculation_orders = np.concatenate(calculation_orders, axis=0)
    calculation_masks = np.concatenate(calculation_masks, axis=0)
    membership = np.array(membership)
    
    atoms_dict = {
        self.atom_features_placeholder: atoms_all,
        self.membership_placeholder: membership,
        self.parents_placeholder: parents_all,
        self.calculation_orders_placeholder: calculation_orders
        self.calculation_orders_placeholder: calculation_orders,
        self.calculation_masks_placeholder: calculation_masks,
        self.membership_placeholder: membership,
        self.n_atoms_placeholder: n_atoms
    }

    return atoms_dict

  def index_changing(self, index, n_mol):
    output = np.zeros_like(index)
    for ide, element in enumerate(index):
      if element < self.max_atoms:
        output[ide] = element + n_mol * self.max_atoms
      else:
        output[ide] = self.batch_size * self.max_atoms
    return output


class WeaveGraphTopology(GraphTopology):
  """Manages placeholders associated with batch of graphs and their topology"""
+25 −42
Original line number Diff line number Diff line
@@ -919,7 +919,6 @@ class DTNNStep(Layer):
    self.build()
    atom_features = x[0]
    distance = x[1]
    atom_membership = x[2]
    distance_membership_i = x[3]
    distance_membership_j = x[4]
    distance_hidden = tf.matmul(distance, self.W_df) + self.b_df
@@ -1016,19 +1015,20 @@ class DAGLayer(Layer):

  def __init__(self,
               n_graph_feat=30,
               n_atom_features=75,
               n_atom_feat=75,
               layer_sizes=[100],
               init='glorot_uniform',
               activation='relu',
               dropout=None,
               max_atoms=50,
               batch_size=64,
               **kwargs):
    """
    Parameters
    ----------
    n_graph_feat: int
      Number of features for each node(and the whole grah).
    n_atom_features: int
    n_atom_feat: int
      Number of features listed per atom.
    layer_sizes: list of int, optional(default=[1000])
      Structure of hidden layer(s)
@@ -1048,11 +1048,12 @@ class DAGLayer(Layer):
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
    self.n_inputs = n_atom_features + (self.max_atoms - 1) * n_graph_feat
    self.batch_size = batch_size
    self.n_inputs = n_atom_feat + (self.max_atoms - 1) * n_graph_feat
    # number of inputs each step
    self.n_graph_feat = n_graph_feat
    self.n_outputs = n_graph_feat
    self.n_atom_features = n_atom_features
    self.n_atom_feat = n_atom_feat

  def build(self):
    """"Construct internal trainable weights.
@@ -1093,51 +1094,43 @@ class DAGLayer(Layer):
    """
    # Add trainable weights
    self.build()

    # Extract atom_features
    # Basic features of every atom: (batch_size*max_atoms) * n_atom_features
    atom_features = x[0]

    # calculation orders of graph: (batch_size*max_atoms) * max_atoms * max_atoms
    # each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
    # each gragh include `max_atoms` of steps(corresponding to rows) of calculating graph features
    # step i calculates the graph features for atoms of index `parents[:,i,0]`
    parents = x[1]

    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    # represent the same atoms of `parents[:, :, 0]`, 
    # different in that these index are positions in `atom_features`
    # paded with max_atoms*batch_size
    calculation_orders = x[2]
    # flags: (batch_size*max_atoms)
    # 0 for paddings, 1 for real atoms
    membership = x[3]
    calculation_masks = x[3]
    # number of atoms in total, should equal `batch_size*max_atoms`
    n_atoms = atom_features.get_shape()[0]
    n_atoms = x[5]
    
    graph_features_initial = tf.zeros((self.max_atoms*self.batch_size, self.max_atoms+1, self.n_graph_feat))
    # initialize graph features for each graph
    # another row of zeros is generated for padded dummy atoms
    graph_features = tf.Variable(
        tf.constant(0., shape=(n_atoms, self.max_atoms + 1, self.n_graph_feat)),
        graph_features_initial,
        trainable=False)
    # add dummy
    atom_features = tf.concat(
        axis=0,
        values=[
            atom_features, tf.constant(0., shape=(1, self.n_atom_features))
        ])
    
    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      mask = calculation_masks[:, count]
      current_round = tf.boolean_mask(calculation_orders[:, count], mask)
      batch_atom_features = tf.gather(atom_features,
                                      calculation_orders[:, count])
                                      current_round)

      # generating index for graph features used in the inputs
      index = tf.stack(
          [
              tf.reshape(
                  tf.stack([tf.range(n_atoms)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(parents[:, count, 1:], [-1])
                  tf.stack([tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
          ],
          axis=1)
      # extracting graph features for parents of the target atoms, then flatten
@@ -1156,23 +1149,13 @@ class DAGLayer(Layer):

      # index for targe atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      # index for dummies
      target_index2 = tf.stack(
          [tf.range(n_atoms), tf.constant(self.max_atoms, shape=(n_atoms,))],
          axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_index,
                                            batch_outputs)
      # recover dummies to zeros if being updated
      graph_features = tf.scatter_nd_update(graph_features, target_index2,
                                            tf.zeros(
                                                (n_atoms, self.n_graph_feat)))

    # last step generates graph features for all target atoms
    # masking the outputs
    outputs = tf.multiply(batch_outputs,
                          tf.expand_dims(tf.to_float(membership), axis=1))
    return outputs

    # last step generates graph features for all target atom
    return batch_outputs

  def DAGgraph_step(self, batch_inputs, W_list, b_list):
    outputs = batch_inputs
@@ -1190,7 +1173,7 @@ class DAGGather(Layer):
  def __init__(self,
               n_graph_feat=30,
               n_outputs=30,
               layer_sizes=[1000],
               layer_sizes=[100],
               init='glorot_uniform',
               activation='relu',
               dropout=None,
@@ -1262,10 +1245,10 @@ class DAGGather(Layer):
    """
    # Add trainable weights
    self.build()

    atom_features = x[0]
    membership = x[1]
    # Extract atom_features
    graph_features = tf.reshape(x, [-1, self.max_atoms, self.n_graph_feat])
    graph_features = tf.reduce_sum(graph_features, axis=1)
    graph_features = tf.segment_sum(atom_features, membership)
    # sum all graph outputs
    outputs = self.DAGgraph_step(graph_features, self.W_list, self.b_list)
    return outputs
Loading