Commit 06ecfb87 authored by miaecle's avatar miaecle
Browse files

rebuild DAG

parent 8233320d
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
@@ -706,6 +706,50 @@ class TestOverfit(test_util.TensorFlowTestCase):

    assert scores[regression_metric.name] > .9

  def test_DAG_singletask_regression_overfit(self):
    """Test DAG regressor singletask overfits tiny data."""
    np.random.seed(123)
    tf.set_random_seed(123)
    n_tasks = 1

    # Featurize the tiny regression CSV (single "outcome" task).
    featurizer = dc.feat.ConvMolFeaturizer()
    tasks = ["outcome"]
    input_file = os.path.join(self.current_dir, "example_regression.csv")
    loader = dc.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
    dataset = loader.featurize(input_file)

    # Overfitting is judged by Pearson R^2 on the training set itself.
    regression_metric = dc.metrics.Metric(
        dc.metrics.pearson_r2_score, task_averager=np.mean)

    n_feat = 75
    batch_size = 10

    # Build the DAG graph: one DAG layer followed by a gather layer.
    # max_atoms must be at least the size of the largest molecule in the data.
    graph = dc.nn.SequentialDAGGraph(
        n_feat, batch_size=batch_size, max_atoms=50)
    graph.add(dc.nn.DAGLayer(30, n_feat, max_atoms=50))
    graph.add(dc.nn.DAGGather(max_atoms=50))

    model = dc.models.MultitaskGraphRegressor(
        graph,
        n_tasks,
        n_feat,
        batch_size=batch_size,
        learning_rate=0.005,
        learning_rate_decay_time=1000,
        optimizer_type="adam",
        beta1=.9,
        beta2=.999)

    # Fit trained model
    model.fit(dataset, nb_epoch=50)
    model.save()
    # Eval model on train
    scores = model.evaluate(dataset, [regression_metric])

    assert scores[regression_metric.name] > .9

  def test_siamese_singletask_classification_overfit(self):
    """Test siamese singletask model overfits tiny data."""
    np.random.seed(123)
+34 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ __license__ = "MIT"

import tensorflow as tf
from deepchem.nn.layers import GraphGather
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology


class SequentialGraph(object):
@@ -129,6 +129,39 @@ class SequentialDTNNGraph(SequentialGraph):
      self.layers.append(layer)


class SequentialDAGGraph(SequentialGraph):
  """SequentialGraph variant wired for DAG models.

  Builds on a DAGGraphTopology instead of the default GraphTopology and
  feeds DAG layers the extra topology placeholders they require.
  """

  def __init__(self, n_feat, batch_size=50, max_atoms=50):
    """
    Parameters
    ----------
    n_feat: int
      Number of features per atom.
    batch_size: int, optional(default=50)
      Number of molecules in a batch
    max_atoms: int, optional(default=50)
      Maximum number of atoms in a molecule, should be defined based on dataset
    """
    self.layers = []
    self.graph = tf.Graph()
    with self.graph.as_default():
      topology = DAGGraphTopology(n_feat, batch_size, max_atoms=max_atoms)
      self.graph_topology = topology
      self.output = topology.get_atom_features_placeholder()

  def add(self, layer):
    """Adds a new layer to model."""
    with self.graph.as_default():
      if type(layer).__name__ == 'DAGLayer':
        # DAG layers consume the topology placeholders alongside the
        # running atom-feature tensor.
        layer_inputs = [self.output]
        layer_inputs += self.graph_topology.get_topology_placeholders()
        self.output = layer(layer_inputs)
      else:
        self.output = layer(self.output)
      self.layers.append(layer)

class SequentialSupportGraph(object):
  """An analog of Keras Sequential model for test/support models."""

+173 −0
Original line number Diff line number Diff line
@@ -258,3 +258,176 @@ class DTNNGraphTopology(GraphTopology):
    steps = np.array([distance_min + i * step_size for i in range(n_distance)])
    distance_vector = np.exp(-np.square(distance - steps) / (2 * step_size**2))
    return distance_vector


class DAGGraphTopology(GraphTopology):
  """GraphTopology for DAG models.

  Holds the placeholders a DAG model needs: per-atom features plus, for
  every atom, the DAG ("parents") structure and batch-level calculation
  orders describing how graph features propagate toward that atom.
  """

  def __init__(self, n_feat, batch_size, name='topology', max_atoms=50):
    """
    Parameters
    ----------
    n_feat: int
      Number of features per atom.
    batch_size: int
      Number of molecules in a batch.
    name: str, optional
      Name prefix for the placeholders.
    max_atoms: int, optional
      Maximum number of atoms in a molecule; every molecule is padded to
      this size.
    """
    self.n_feat = n_feat
    self.name = name
    self.max_atoms = max_atoms
    self.batch_size = batch_size
    # Basic features of every (padded) atom:
    # (batch_size * max_atoms) x n_feat
    self.atom_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(self.batch_size * self.max_atoms, self.n_feat),
        name=self.name + '_atom_features')

    self.parents_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms, self.max_atoms,
               self.max_atoms),
        # molecule * atom(graph) => step => features
        name=self.name + '_parents')

    self.calculation_orders_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms, self.max_atoms),
        # molecule * atom(graph) => step
        name=self.name + '_orders')

    # 1 for real atoms, 0 for padding dummies.
    self.membership_placeholder = tf.placeholder(
        dtype='int32',
        shape=(self.batch_size * self.max_atoms,),
        name=self.name + '_membership')

    # Define the list of tensors to be used as topology
    self.topology = [
        self.parents_placeholder, self.calculation_orders_placeholder,
        self.membership_placeholder
    ]

    self.inputs = [self.atom_features_placeholder]
    self.inputs += self.topology

  def get_parents_placeholder(self):
    """Return the placeholder describing each atom's DAG structure."""
    return self.parents_placeholder

  def get_calculation_orders_placeholder(self):
    """Return the placeholder of batch-level calculation orders."""
    return self.calculation_orders_placeholder

  def batch_to_feed_dict(self, batch):
    """Converts the current batch of mol_graphs into tensorflow feed_dict.

    Assigns the graph information in array of ConvMol objects to the
    placeholders tensors for DAG models

    params
    ------
    batch : np.ndarray
      Array of ConvMol objects

    returns
    -------
    feed_dict : dict
      Can be merged with other feed_dicts for input into tensorflow
    """
    atoms_per_mol = [mol.get_num_atoms() for mol in batch]
    n_atom_features = batch[0].get_atom_features().shape[1]
    # Mark real atoms with 1 and padding atoms with 0 for every molecule.
    membership = np.concatenate(
        [
            np.array([1] * n_atoms + [0] * (self.max_atoms - n_atoms))
            for n_atoms in atoms_per_mol
        ],
        axis=0)

    atoms_all = []
    parents_all = []
    calculation_orders = []
    for idm, mol in enumerate(batch):
      # Pad each molecule's atom-feature matrix with zero rows up to
      # max_atoms.
      atom_features_padded = np.concatenate(
          [
              mol.get_atom_features(),
              np.zeros((self.max_atoms - atoms_per_mol[idm], n_atom_features))
          ],
          axis=0)
      atoms_all.append(atom_features_padded)

      # ConvMol objects input here should have gone through the DAG
      # Transformer, so every atom yields one DAG.
      parents = self.UG_to_DAG(mol)
      assert len(parents) == atoms_per_mol[idm]
      parents_all.extend(parents)
      # Pad missing graphs with all-max_atoms matrices; max_atoms is the
      # index of the dummy (all-zero) graph-feature row.
      parents_all.extend([
          self.max_atoms * np.ones((self.max_atoms, self.max_atoms), dtype=int)
          for i in range(self.max_atoms - atoms_per_mol[idm])
      ])
      # Shift within-molecule atom indices to batch-level indices.
      for parent in parents:
        calculation_orders.append(self.indice_changing(parent[:, 0], idm))
      # Pad with batch_size * max_atoms, the index of the dummy
      # atom-feature row.
      calculation_orders.extend([
          self.batch_size * self.max_atoms * np.ones(
              (self.max_atoms,), dtype=int)
          for i in range(self.max_atoms - atoms_per_mol[idm])
      ])

    atoms_all = np.concatenate(atoms_all, axis=0)
    parents_all = np.stack(parents_all, axis=0)
    calculation_orders = np.stack(calculation_orders, axis=0)
    atoms_dict = {
        self.atom_features_placeholder: atoms_all,
        self.membership_placeholder: membership,
        self.parents_placeholder: parents_all,
        self.calculation_orders_placeholder: calculation_orders
    }

    return atoms_dict

  def indice_changing(self, indice, n_mol):
    """Map within-molecule atom indices to batch-level indices.

    Real indices (< max_atoms) are offset by n_mol * max_atoms; padding
    indices (== max_atoms) are redirected to batch_size * max_atoms, the
    dummy atom-feature row.

    params
    ------
    indice : np.ndarray
      1-D integer array of within-molecule indices.
    n_mol : int
      Position of the molecule within the batch.

    returns
    -------
    np.ndarray
      Batch-level indices, same shape and dtype as `indice`.
    """
    indice = np.asarray(indice)
    # Vectorized equivalent of the per-element loop: O(n) in C, not Python.
    return np.where(indice < self.max_atoms,
                    indice + n_mol * self.max_atoms,
                    self.batch_size * self.max_atoms)

  def UG_to_DAG(self, sample):
    """Expand an undirected molecular graph into one DAG per atom.

    For each atom, a breadth-first traversal rooted at that atom orders
    the remaining atoms by bond distance; parents lists are then built by
    walking the resulting edges in reverse. Each returned matrix is
    max_atoms x max_atoms, padded with the value max_atoms.

    params
    ------
    sample : ConvMol
      Molecule providing get_adjacency_list() and get_num_atoms().

    returns
    -------
    list of np.ndarray
      One (max_atoms, max_atoms) parents matrix per atom.
    """
    parents = []
    UG = sample.get_adjacency_list()
    n_atoms = sample.get_num_atoms()
    max_atoms = self.max_atoms
    for count in range(n_atoms):
      DAG = []
      parent = [[] for i in range(n_atoms)]
      current_atoms = [count]
      # first element is current atom
      atoms_indicator = np.ones((n_atoms,))
      # 1 means not yet included in the graph
      atoms_indicator[count] = 0
      radial = 0
      while np.sum(atoms_indicator) > 0:
        if radial > n_atoms:
          break  # molecules with two separate ions may get stuck here
        next_atoms = []
        for current_atom in current_atoms:
          for atom_adj in UG[current_atom]:
            # atoms connected to current_atom
            if atoms_indicator[atom_adj] > 0:
              DAG.append((current_atom, atom_adj))
              atoms_indicator[atom_adj] = 0
              # tag atom as included
              next_atoms.append(atom_adj)
        current_atoms = next_atoms
        # into next step, finding atoms connected with one more bond
        radial = radial + 1
      for edge in reversed(DAG):
        # each atom inherits its direct parent plus that parent's parents
        parent[edge[0]].append(edge[1])
        parent[edge[0]].extend(parent[edge[1]])
      for ids, atom in enumerate(parent):
        parent[ids].insert(0, ids)
      # sort by number of parents so leaves are computed first
      parent = sorted(parent, key=len)
      for ids, atom in enumerate(parent):
        n_par = len(atom)
        parent[ids].extend([max_atoms for i in range(max_atoms - n_par)])
      while len(parent) < max_atoms:
        parent.insert(0, [max_atoms] * max_atoms)
      parents.append(np.array(parent))
    return parents
+4 −0
Original line number Diff line number Diff line
@@ -17,6 +17,8 @@ from deepchem.nn.layers import ResiLSTMEmbedding
from deepchem.nn.layers import DTNNEmbedding
from deepchem.nn.layers import DTNNStep
from deepchem.nn.layers import DTNNGather
from deepchem.nn.layers import DAGLayer
from deepchem.nn.layers import DAGGather

from deepchem.nn.model_ops import weight_decay
from deepchem.nn.model_ops import optimizer
@@ -28,6 +30,8 @@ from deepchem.nn.objectives import mean_squared_error

from deepchem.models.tf_new_models.graph_topology import GraphTopology
from deepchem.models.tf_new_models.graph_topology import DTNNGraphTopology
from deepchem.models.tf_new_models.graph_topology import DAGGraphTopology
from deepchem.models.tf_new_models.graph_models import SequentialGraph
from deepchem.models.tf_new_models.graph_models import SequentialDTNNGraph
from deepchem.models.tf_new_models.graph_models import SequentialDAGGraph
from deepchem.models.tf_new_models.graph_models import SequentialSupportGraph
+267 −0
Original line number Diff line number Diff line
@@ -992,3 +992,270 @@ class DTNNGather(Layer):
        tf.multiply(output, tf.expand_dims(atom_mask, axis=2)), axis=1)

    return output


class DAGLayer(Layer):
  """Main layer of DAG model.

  For a molecule with n atoms, n different graphs are generated and run
  through the network. The final outputs of each graph become the graph
  features of the corresponding atom, which will be summed and put into
  another network in the DAGGather layer.
  """

  def __init__(self,
               n_graph_feat=30,
               n_atom_features=75,
               layer_sizes=[100],
               init='glorot_uniform',
               activation='relu',
               dropout=None,
               max_atoms=50,
               **kwargs):
    """
    Parameters
    ----------
    n_graph_feat: int
      Number of features for each node (and the whole graph).
    n_atom_features: int
      Number of features listed per atom.
    layer_sizes: list of int, optional(default=[100])
      Structure of hidden layer(s)
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function applied
    dropout: float, optional
      Dropout probability, not supported here
    max_atoms: int, optional
      Maximum number of atoms in molecules.
    """
    super(DAGLayer, self).__init__(**kwargs)

    self.init = initializations.get(init)  # Set weight initialization
    self.activation = activations.get(activation)  # Get activations
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
    # Each step consumes the target atom's features plus the graph features
    # of up to (max_atoms - 1) parents.
    self.n_inputs = n_atom_features + (self.max_atoms - 1) * n_graph_feat
    self.n_graph_feat = n_graph_feat
    self.n_outputs = n_graph_feat
    self.n_atom_features = n_atom_features

  def build(self):
    """Construct internal trainable weights.

    Builds one (W, b) pair per hidden layer in layer_sizes plus a final
    output layer mapping to n_outputs features.
    """

    self.W_list = []
    self.b_list = []
    prev_layer_size = self.n_inputs
    for layer_size in self.layer_sizes:
      self.W_list.append(self.init([prev_layer_size, layer_size]))
      self.b_list.append(model_ops.zeros(shape=[
          layer_size,
      ]))
      prev_layer_size = layer_size
    self.W_list.append(self.init([prev_layer_size, self.n_outputs]))
    self.b_list.append(model_ops.zeros(shape=[
        self.n_outputs,
    ]))

    self.trainable_weights = self.W_list + self.b_list

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    x = [atom_features, parents, calculation_orders, membership]

    Parameters
    ----------
    x: list
      list of Tensors of form described above.
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    outputs: tf.Tensor
      Tensor of atom features, of shape (n_atoms, n_graph_feat)
    """
    # Add trainable weights
    self.build()

    # Basic features of every atom:
    # shape (batch_size*max_atoms) x n_atom_features
    atom_features = x[0]

    parents = x[1]
    # Structure of graph: (batch_size*max_atoms) x max_atoms x max_atoms.
    # Each atom corresponds to a graph, represented by a max_atoms x
    # max_atoms int32 matrix of indices. There are max_atoms steps (rows);
    # in step i, the graph features of atom(i, 0) are computed from that
    # atom's features and the graph features of its parents,
    # atom(i, 1) .. atom(i, max_atoms-1). Rows with fewer than max_atoms-1
    # parents are padded with max_atoms, which indexes an all-zero dummy row.
    calculation_orders = x[2]
    # shape (batch_size*max_atoms) x max_atoms.
    # Same atoms as parents[:, :, 0], but indexed into atom_features
    # (0 .. max_atoms*batch_size), padded with max_atoms*batch_size.
    membership = x[3]
    # shape (batch_size*max_atoms); 0 for dummy atoms, 1 for real atoms.
    n_atoms = atom_features.get_shape()[0]
    # NOTE(review): this is the leading dimension of atom_features, i.e.
    # batch_size*max_atoms (the original comment claimed "=batch_size") —
    # confirm against callers.
    graph_features = tf.Variable(
        tf.constant(0., shape=(n_atoms, self.max_atoms + 1, self.n_graph_feat)),
        trainable=False)
    # Graph features for every atom's graph, with one extra all-zero row per
    # graph serving as the dummy parent.
    atom_features = tf.concat(
        axis=0,
        values=[
            atom_features, tf.constant(0., shape=(1, self.n_atom_features))
        ])
    # Append a dummy all-zero atom-feature row, indexed by the padding value.
    for count in range(self.max_atoms):
      # count-th step
      batch_atom_features = tf.gather(atom_features,
                                      calculation_orders[:, count])
      # atom features of the target atoms,
      # shape: (batch_size*max_atoms) x n_atom_features

      indice = tf.stack(
          [
              tf.reshape(
                  tf.stack([tf.range(n_atoms)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(parents[:, count, 1:], [-1])
          ],
          axis=1)
      # (graph, parent) index pairs selecting the parents' graph features
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, indice),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])
      # graph features of the parents of the target atoms, flattened;
      # shape: (batch_size*max_atoms) x [(max_atoms-1)*n_graph_features]
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # full step inputs, shape: (batch_size*max_atoms) x n_inputs
      batch_outputs = self.DAGgraph_step(batch_inputs, self.W_list, self.b_list)
      # graph features of the target atoms of this step,
      # shape: (batch_size*max_atoms) x n_graph_features
      target_indices = tf.stack(
          [tf.range(n_atoms), parents[:, count, 0]], axis=1)
      target_indices2 = tf.stack(
          [tf.range(n_atoms), tf.constant(self.max_atoms, shape=(n_atoms,))],
          axis=1)
      graph_features = tf.scatter_nd_update(graph_features, target_indices,
                                            batch_outputs)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_indices2,
                                            tf.zeros(
                                                (n_atoms, self.n_graph_feat)))
      # restore the dummy rows to zero in case the padded targets updated them

    outputs = tf.multiply(batch_outputs,
                          tf.expand_dims(tf.to_float(membership), axis=1))
    # mask out dummy atoms from the final step's outputs
    return outputs

  def DAGgraph_step(self, batch_inputs, W_list, b_list):
    """Apply the fully-connected stack: xw_plus_b + activation per layer."""
    outputs = batch_inputs
    for idw, W in enumerate(W_list):
      outputs = tf.nn.xw_plus_b(outputs, W, b_list[idw])
      outputs = self.activation(outputs)
    return outputs


class DAGGather(Layer):
  """Gather layer of the DAG model.

  For each molecule, the per-atom graph outputs are summed and fed through
  another fully-connected network to produce molecule-level features.
  """

  def __init__(self,
               n_graph_feat=30,
               n_outputs=30,
               layer_sizes=None,
               init='glorot_uniform',
               activation='relu',
               dropout=None,
               max_atoms=50,
               **kwargs):
    """
    Parameters
    ----------
    n_graph_feat: int
      Number of features for each atom
    n_outputs: int
      Number of features for each molecule.
    layer_sizes: list of int, optional(default=[1000])
      Structure of hidden layer(s)
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function applied
    dropout: float, optional
      Dropout probability, not supported
    max_atoms: int, optional
      Maximum number of atoms in molecules.
    """
    super(DAGGather, self).__init__(**kwargs)

    # None-sentinel avoids a shared mutable default argument; [1000]
    # remains the effective default hidden-layer structure.
    if layer_sizes is None:
      layer_sizes = [1000]
    self.init = initializations.get(init)  # Set weight initialization
    self.activation = activations.get(activation)  # Get activations
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
    self.n_graph_feat = n_graph_feat
    self.n_outputs = n_outputs

  def build(self):
    """Construct internal trainable weights.

    Builds one (W, b) pair per hidden layer in layer_sizes plus a final
    output layer mapping to n_outputs features.
    """
    self.W_list = []
    self.b_list = []
    prev_layer_size = self.n_graph_feat
    for layer_size in self.layer_sizes:
      self.W_list.append(self.init([prev_layer_size, layer_size]))
      self.b_list.append(model_ops.zeros(shape=[
          layer_size,
      ]))
      prev_layer_size = layer_size
    self.W_list.append(self.init([prev_layer_size, self.n_outputs]))
    self.b_list.append(model_ops.zeros(shape=[
        self.n_outputs,
    ]))

    self.trainable_weights = self.W_list + self.b_list

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    Parameters
    ----------
    x: tf.Tensor
      Tensor of each atom's graph features,
      shape (batch_size*max_atoms) x n_graph_feat.
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    outputs: tf.Tensor
      Tensor of each molecule's features.
    """
    # Add trainable weights
    self.build()

    # Regroup atoms by molecule, then sum all graph outputs per molecule.
    graph_features = tf.reshape(x, [-1, self.max_atoms, self.n_graph_feat])
    graph_features = tf.reduce_sum(graph_features, axis=1)
    outputs = self.DAGgraph_step(graph_features, self.W_list, self.b_list)
    return outputs

  def DAGgraph_step(self, batch_inputs, W_list, b_list):
    """Apply the fully-connected stack: xw_plus_b + activation per layer."""
    outputs = batch_inputs
    for idw, W in enumerate(W_list):
      outputs = tf.nn.xw_plus_b(outputs, W, b_list[idw])
      outputs = self.activation(outputs)
    return outputs
Loading