Commit c93cb545 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

new version of generator

parent 109d8ed2
Loading
Loading
Loading
Loading
+157 −60
Original line number Diff line number Diff line
from typing import List, Tuple
from typing import List, Tuple, Any

import tensorflow as tf
from deepchem.feat.molecule_featurizers.molgan_featurizer import GraphMatrix
@@ -119,63 +119,20 @@ class BasicMolGANModel(WGAN):
    Take noise data as an input and processes it through number of
    dense and dropout layers. Then data is converted into two forms
    one used for training and other for generation of compounds.
    The model has four outputs:
      1. edges logits used during training
      2. nodes logits used during training
      3. edges logits used for compound generation
      4. nodes logits used for compound generation
    """
    input_layer = layers.Input(shape=(self.embedding_dim,))
    x = layers.Dense(128, activation="tanh")(input_layer)
    x = layers.Dropout(self.dropout_rate)(x)
    x = layers.Dense(256, activation="tanh")(x)
    x = layers.Dropout(self.dropout_rate)(x)
    x = layers.Dense(512, activation="tanh")(x)
    x = layers.Dropout(self.dropout_rate)(x)

    # edges logits used during training
    edges_logits = layers.Dense(
        units=self.edges * self.vertices * self.vertices, activation=None)(x)
    edges_logits = layers.Reshape((self.edges, self.vertices,
                                   self.vertices))(edges_logits)
    matrix_transpose = layers.Permute((1, 3, 2))(edges_logits)
    edges_logits = (edges_logits + matrix_transpose) / 2
    edges_logits = layers.Permute((2, 3, 1))(edges_logits)
    edges_logits = layers.Dropout(self.dropout_rate)(edges_logits)
    edges_softmax = tf.nn.softmax(edges_logits)

    # nodes logits used during training
    nodes_logits = layers.Dense(
        units=(self.vertices * self.nodes), activation=None)(x)
    nodes_logits = layers.Reshape((self.vertices, self.nodes))(nodes_logits)
    nodes_logits = layers.Dropout(self.dropout_rate)(nodes_logits)
    nodes_softmax = tf.nn.softmax(nodes_logits)

    # edges logits used for compound generation
    e_gumbel_logits = edges_logits - tf.math.log(-tf.math.log(
        tf.random.uniform(tf.shape(edges_logits), dtype=edges_logits.dtype)))
    e_gumbel_argmax = tf.one_hot(
        tf.argmax(e_gumbel_logits, axis=-1),
        depth=e_gumbel_logits.shape[-1],
        dtype=e_gumbel_logits.dtype,
    )
    e_argmax = tf.argmax(e_gumbel_argmax, axis=-1)

    # nodes logits used during compound generation
    n_gumbel_logits = nodes_logits - tf.math.log(-tf.math.log(
        tf.random.uniform(tf.shape(nodes_logits), dtype=nodes_logits.dtype)))
    n_gumbel_argmax = tf.one_hot(
        tf.argmax(n_gumbel_logits, axis=-1),
        depth=n_gumbel_logits.shape[-1],
        dtype=n_gumbel_logits.dtype,
    )
    n_argmax = tf.argmax(n_gumbel_argmax, axis=-1)

    # final model, first 2 outputs are for training, last two are for compound generation
    return keras.Model(
        inputs=input_layer,
        outputs=[edges_softmax, nodes_softmax, e_argmax, n_argmax],
    )
    The model has two outputs:
      1. edges
      2. nodes
    The format differs depending on intended use (training or sample generation).
    For sample generation, pass the flag sample_generation=True when calling
    the generator, i.e.
    gan.generators[0](noise_input, training=False, sample_generation=True).
    For training, no flag is necessary.
    """
    return BasicMolGANGenerator(
        vertices=self.vertices,
        edges=self.edges,
        nodes=self.nodes,
        dropout_rate=self.dropout_rate,
        embedding_dim=self.embedding_dim)

  def create_discriminator(self) -> keras.Model:
    """
@@ -245,10 +202,150 @@ class BasicMolGANModel(WGAN):
      batch_size = len(noise_input)
    if noise_input is None:
      noise_input = self.get_noise_batch(batch_size)
    _, _, adjacency_matrix, nodes_features = self.generators[0](
        noise_input, training=False)
    print(f"Generating {batch_size} samples")
    adjacency_matrix, nodes_features = self.generators[0](
        noise_input, training=False, sample_generation=True)
    graphs = [
        GraphMatrix(i, j)
        for i, j in zip(adjacency_matrix.numpy(), nodes_features.numpy())
    ]
    return graphs


class BasicMolGANGenerator(tf.keras.Model):
  """
  Generator class for BasicMolGAN model.

  Subclassing is used rather than the functional API because the same
  network has to return one of two output formats depending on how it is
  called:

  * training (default): softmax values computed from the edges and nodes
    logits,
  * sample generation: indices sampled from the same logits with Gumbel
    noise followed by an argmax (used for conversion to rdkit molecules).

  To get the sample generation output pass sample_generation=True while
  calling the model, i.e.

      adjacency_matrix, nodes_features = self.generators[0](
          noise_input, training=False, sample_generation=True)

  This is automatically done in predict_gan_generator().
  """

  def __init__(self,
               vertices: int = 9,
               edges: int = 5,
               nodes: int = 5,
               dropout_rate: float = 0.,
               embedding_dim: int = 10,
               name: str = "SimpleMolGANGenerator",
               **kwargs):
    """
    Initialize model.

    Parameters
    ----------
    vertices : int, optional
        number of max atoms dataset molecules (incl. empty atom), by default 9
    edges : int, optional
        number of bond types in molecules, by default 5
    nodes : int, optional
        number of atom types in molecules, by default 5
    dropout_rate : float, optional
        rate of dropout, by default 0.
    embedding_dim : int, optional
        noise input dimensions, by default 10
    name : str, optional
        name of the model, by default "SimpleMolGANGenerator"
    **kwargs
        forwarded unchanged to the tf.keras.Model constructor
    """
    super(BasicMolGANGenerator, self).__init__(name=name, **kwargs)
    self.vertices = vertices
    self.edges = edges
    self.nodes = nodes
    self.dropout_rate = dropout_rate
    self.embedding_dim = embedding_dim

    # shared trunk: three tanh dense layers, each followed by dropout
    self.dense1 = layers.Dense(
        128, activation="tanh", input_shape=(self.embedding_dim,))
    self.dropout1 = layers.Dropout(self.dropout_rate)
    self.dense2 = layers.Dense(256, activation="tanh")
    self.dropout2 = layers.Dropout(self.dropout_rate)
    self.dense3 = layers.Dense(512, activation="tanh")
    self.dropout3 = layers.Dropout(self.dropout_rate)

    # edges head: flat logits reshaped to (edges, vertices, vertices)
    self.edges_dense = layers.Dense(
        units=self.edges * self.vertices * self.vertices, activation=None)
    self.edges_reshape = layers.Reshape((self.edges, self.vertices,
                                         self.vertices))
    # swaps the two vertex axes, used to symmetrise the edge logits
    self.edges_matrix_transpose1 = layers.Permute((1, 3, 2))
    # moves the edge-type axis last: (vertices, vertices, edges)
    self.edges_matrix_transpose2 = layers.Permute((2, 3, 1))
    self.edges_dropout = layers.Dropout(self.dropout_rate)

    # nodes head: flat logits reshaped to (vertices, nodes)
    self.nodes_dense = layers.Dense(
        units=(self.vertices * self.nodes), activation=None)
    self.nodes_reshape = layers.Reshape((self.vertices, self.nodes))
    self.nodes_dropout = layers.Dropout(self.dropout_rate)

  @staticmethod
  def _gumbel_argmax(logits: Any) -> Any:
    """Perturb logits with Gumbel noise and return the argmax of the last axis."""
    uniform_noise = tf.random.uniform(tf.shape(logits), dtype=logits.dtype)
    gumbel_logits = logits - tf.math.log(-tf.math.log(uniform_noise))
    # argmax(one_hot(argmax(x))) is the same as argmax(x), so the index is
    # returned directly without building an intermediate one-hot tensor.
    return tf.argmax(gumbel_logits, axis=-1)

  def call(self,
           inputs: Any,
           training: bool = False,
           sample_generation: bool = False) -> List[Any]:
    """
    Call generator model

    Parameters
    ----------
    inputs : Any
        List of inputs, typically noise_batch
    training : bool, optional
        passed to every dropout layer, by default False
    sample_generation : bool, optional
        decide which output to use, by default False

    Returns
    -------
    List[Any]
        Two tensors, [edges, nodes], containing either softmax values for
        training or argmax indices for sample generation (used for creation
        of rdkit molecules).
    """
    # shared trunk; the training flag is handed to dropout explicitly
    # instead of relying on implicit propagation
    x = self.dense1(inputs)
    x = self.dropout1(x, training=training)
    x = self.dense2(x)
    x = self.dropout2(x, training=training)
    x = self.dense3(x)
    x = self.dropout3(x, training=training)

    # edges logits: average with the vertex-transposed copy so the result
    # is symmetric in the two vertex axes, then put edge types last
    edges_logits = self.edges_reshape(self.edges_dense(x))
    transposed_logits = self.edges_matrix_transpose1(edges_logits)
    edges_logits = (edges_logits + transposed_logits) / 2
    edges_logits = self.edges_matrix_transpose2(edges_logits)
    edges_logits = self.edges_dropout(edges_logits, training=training)

    # nodes logits
    nodes_logits = self.nodes_reshape(self.nodes_dense(x))
    nodes_logits = self.nodes_dropout(nodes_logits, training=training)

    if sample_generation:
      # generating compounds: edges are sampled first, then nodes
      edges = self._gumbel_argmax(edges_logits)
      nodes = self._gumbel_argmax(nodes_logits)
    else:
      # training of the model
      edges = tf.nn.softmax(edges_logits)
      nodes = tf.nn.softmax(nodes_logits)

    return [edges, nodes]
+1 −5
Original line number Diff line number Diff line
@@ -67,11 +67,6 @@ class test_molgan_model(unittest.TestCase):
    # check training nodes logits shapes
    assert model.generators[0].output_shape[1] == (None, self.vertices,
                                                   self.nodes)
    # check molecule generation edges logits shapes
    assert model.generators[0].output_shape[2] == (None, self.vertices,
                                                   self.vertices)
    # check molecule generation nodes logits shapes
    assert model.generators[0].output_shape[3] == (None, self.vertices)

  def test_training(self):
    """
@@ -123,6 +118,7 @@ class test_molgan_model(unittest.TestCase):
      generated_molecules = feat.defeaturize(g)
      valid_molecules_count = len(
          list(filter(lambda x: x is not None, generated_molecules)))
      print(valid_molecules_count)
      if valid_molecules_count:
        success = True
        break