Unverified Commit 792dbac1 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1615 from peastman/keras

Converted more models to Keras
parents 87410a5e b17a9c17
Loading
Loading
Loading
Loading
+20 −12
Original line number Diff line number Diff line
@@ -311,8 +311,6 @@ class KerasModel(Model):
    the average loss over the most recent checkpoint interval
    """
    self._ensure_built()
    if restore:
      self.restore()
    if checkpoint_interval > 0:
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
@@ -324,6 +322,9 @@ class KerasModel(Model):

    for batch in generator:
      self._create_training_ops(batch)
      if restore:
        self.restore()
        restore = False
      inputs, labels, weights = self._prepare_batch(batch)
      self._tensorboard_step += 1
      should_log = (
@@ -469,6 +470,8 @@ class KerasModel(Model):
          output_values = self._output_functions[outputs](inputs)
        else:
          output_values = self.model(inputs, training=False)
          if isinstance(output_values, tf.Tensor):
            output_values = [output_values]
          output_values = [t.numpy() for t in output_values]
      else:

@@ -500,8 +503,7 @@ class KerasModel(Model):
        elif len(output_values) == 1:
          output_values = [undo_transforms(output_values[0], transformers)]
      if results is None:
        results = [output_values]
      else:
        results = [[] for i in range(len(output_values))]
      for i, t in enumerate(output_values):
        results[i].append(t)

@@ -890,6 +892,12 @@ class KerasModel(Model):
    else:
      self._checkpoint.restore(checkpoint).run_restore_ops(self.session)

  def get_global_step(self):
    """Get the number of steps of fitting that have been performed.

    Returns
    -------
    int
      the current value of the model's global step counter.
    """
    if tf.executing_eagerly():
      # In eager mode the counter's value can be read directly; cast the
      # tf.Variable to a plain Python int.
      return int(self._global_step)
    # In graph mode the variable must be evaluated through the model's session.
    return self._global_step.eval(session=self.session)


class _StandardLoss(object):
  """The implements the loss function for models that use a dc.models.losses.Loss."""
+40 −10
Original line number Diff line number Diff line
@@ -551,6 +551,39 @@ class WeightedLinearCombo(tf.keras.layers.Layer):
    return out_tensor


class CombineMeanStd(tf.keras.layers.Layer):
  """Generate Gaussian noise from per-element mean and standard deviation inputs."""

  def __init__(self, training_only=False, noise_epsilon=0.01, **kwargs):
    """Create a CombineMeanStd layer.

    This layer should have two inputs with the same shape, and its output also has the
    same shape.  Each element of the output is a Gaussian distributed random number
    whose mean is the corresponding element of the first input, and whose standard
    deviation is the corresponding element of the second input.

    Parameters
    ----------
    training_only: bool
      if True, noise is only generated during training.  During prediction, the output
      is simply equal to the first input (that is, the mean of the distribution used
      during training).
    noise_epsilon: float
      NOTE(review): accepted but never stored or used by this implementation —
      confirm whether it was meant to scale the sampled noise.
    """
    super(CombineMeanStd, self).__init__(**kwargs)
    self.training_only = training_only

  def call(self, inputs, training=True):
    # Expect exactly two inputs: inputs[0] is the mean, inputs[1] the std dev.
    if len(inputs) != 2:
      raise ValueError("Must have two in_layers")
    mean_parent, std_parent = inputs[0], inputs[1]
    # Optionally skip sampling at inference time and return the mean directly.
    if self.training_only and not training:
      return mean_parent
    from tensorflow.python.ops import array_ops
    # Draw unit Gaussian noise matching the mean tensor's dynamic shape, then
    # shift/scale it elementwise: output = mean + std * N(0, 1).
    sample_noise = tf.random_normal(
        array_ops.shape(mean_parent), 0, 1, dtype=tf.float32)
    return mean_parent + std_parent * sample_noise


class Stack(tf.keras.layers.Layer):
  """Stack the inputs along a new axis."""

@@ -1025,18 +1058,18 @@ class AtomicConvolution(tf.keras.layers.Layer):
    # Compute the distances and radial symmetry functions.
    D = self.distance_tensor(X, Nbrs, self.boxsize, B, N, M, d)
    R = self.distance_matrix(D)
    R = tf.reshape(R, [1] + R.shape.as_list())
    R = tf.expand_dims(R, 0)
    rsf = self.radial_symmetry_function(R, self.rc, self.rs, self.re)

    if not self.atom_types:
      cond = tf.cast(tf.not_equal(Nbrs_Z, 0), tf.float32)
      cond = tf.reshape(cond, R.shape)
      cond = tf.reshape(cond, (1, -1, N, M))
      layer = tf.reduce_sum(cond * rsf, 3)
    else:
      sym = []
      for j in range(len(self.atom_types)):
        cond = tf.cast(tf.equal(Nbrs_Z, self.atom_types[j]), tf.float32)
        cond = tf.reshape(cond, R.shape)
        cond = tf.reshape(cond, (1, -1, N, M))
        sym.append(tf.reduce_sum(cond * rsf, 3))
      layer = tf.concat(sym, 0)

@@ -1132,13 +1165,10 @@ class AtomicConvolution(tf.keras.layers.Layer):
    D: tf.Tensor of shape (B, N, M, d)
      Coordinates/features distance tensor.
    """
    D = []
    for coords, neighbors in zip(tf.unstack(X), tf.unstack(Nbrs)):
      flat_neighbors = tf.reshape(neighbors, [-1])
      neighbor_coords = tf.gather(coords, flat_neighbors)
      neighbor_coords = tf.reshape(neighbor_coords, [N, M, d])
      D.append(neighbor_coords - tf.expand_dims(coords, 1))
    D = tf.stack(D)
    flat_neighbors = tf.reshape(Nbrs, [-1, N * M])
    neighbor_coords = tf.batch_gather(X, flat_neighbors)
    neighbor_coords = tf.reshape(neighbor_coords, [-1, N, M, d])
    D = neighbor_coords - tf.expand_dims(X, 2)
    if boxsize is not None:
      boxsize = tf.reshape(boxsize, [1, 1, 1, d])
      D -= tf.round(D / boxsize) * boxsize
+32 −103
Original line number Diff line number Diff line
@@ -6,13 +6,10 @@ import numpy as np
import tensorflow as tf

from deepchem.utils.save import log
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.layers import Layer, SigmoidCrossEntropy, \
    Sigmoid, Feature, Label, Weights, Concat, WeightedError, Stack
from deepchem.models.tensorgraph.layers import convert_to_layers
from deepchem.models import KerasModel, layers
from deepchem.models.losses import SigmoidCrossEntropy
from deepchem.trans import undo_transforms

logger = logging.getLogger(__name__)
from tensorflow.keras.layers import Input, Layer, Activation, Concatenate, Lambda


class IRVLayer(Layer):
@@ -20,7 +17,7 @@ class IRVLayer(Layer):
       https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2750043/
  """

  def __init__(self, n_tasks, K, **kwargs):
  def __init__(self, n_tasks, K, penalty, **kwargs):
    """
    Parameters
    ----------
@@ -31,23 +28,16 @@ class IRVLayer(Layer):
    """
    self.n_tasks = n_tasks
    self.K = K
    self.V, self.W, self.b, self.b2 = None, None, None, None
    self.penalty = penalty
    super(IRVLayer, self).__init__(**kwargs)

  def build(self):
  def build(self, input_shape):
    self.V = tf.Variable(tf.constant([0.01, 1.]), name="vote", dtype=tf.float32)
    self.W = tf.Variable(tf.constant([1., 1.]), name="w", dtype=tf.float32)
    self.b = tf.Variable(tf.constant([0.01]), name="b", dtype=tf.float32)
    self.b2 = tf.Variable(tf.constant([0.01]), name="b2", dtype=tf.float32)
    self.trainable_weights = [self.V, self.W, self.b, self.b2]

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)

    self.build()
    inputs = in_layers[0].out_tensor
  def call(self, inputs):
    K = self.K
    outputs = []
    for count in range(self.n_tasks):
@@ -61,53 +51,10 @@ class IRVLayer(Layer):
      R = tf.sigmoid(R)
      z = tf.reduce_sum(R * tf.gather(self.V, ys), axis=1) + self.b2
      outputs.append(tf.reshape(z, shape=[-1, 1]))
    out_tensor = tf.concat(outputs, axis=1)

    if set_tensors:
      self.trainable_variables = self.trainable_weights
      self.out_tensor = out_tensor
    return out_tensor

  def none_tensors(self):
    V, W, b, b2 = self.V, self.W, self.b, self.b2
    self.V, self.W, self.b, self.b2 = None, None, None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.trainable_variables
    self.out_tensor, self.trainable_weights, self.trainable_variables = None, [], []
    return V, W, b, b2, out_tensor, trainable_weights, variables

  def set_tensors(self, tensor):
    self.V, self.W, self.b, self.b2, self.out_tensor, self.trainable_weights, self.trainable_variables = tensor


class IRVRegularize(Layer):
  """ Extracts the trainable weights in IRVLayer
  and return their L2-norm
  No in_layers is required, but should be built after target IRVLayer
  """

  def __init__(self, IRVLayer, penalty=0.0, **kwargs):
    """
    Parameters
    ----------
    IRVLayer: IRVLayer
      Target layer for extracting weights and regularization
    penalty: float
      L2 Penalty strength
    """
    self.IRVLayer = IRVLayer
    self.penalty = penalty
    super(IRVRegularize, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    assert self.IRVLayer.out_tensor is not None, "IRVLayer must be built first"
    out_tensor = tf.nn.l2_loss(self.IRVLayer.W) + \
        tf.nn.l2_loss(self.IRVLayer.V) + tf.nn.l2_loss(self.IRVLayer.b) + \
        tf.nn.l2_loss(self.IRVLayer.b2)
    out_tensor = out_tensor * self.penalty
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
    loss = (tf.nn.l2_loss(self.W) + tf.nn.l2_loss(self.V) + tf.nn.l2_loss(
        self.b) + tf.nn.l2_loss(self.b2)) * self.penalty
    self.add_loss(loss)
    return tf.concat(outputs, axis=1)


class Slice(Layer):
@@ -129,22 +76,13 @@ class Slice(Layer):
    self.axis = axis
    super(Slice, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)

  def call(self, inputs):
    slice_num = self.slice_num
    axis = self.axis
    inputs = in_layers[0].out_tensor
    out_tensor = tf.slice(inputs, [0] * axis + [slice_num], [-1] * axis + [1])
    return tf.slice(inputs, [0] * axis + [slice_num], [-1] * axis + [1])

    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class TensorflowMultitaskIRVClassifier(TensorGraph):
class TensorflowMultitaskIRVClassifier(KerasModel):

  def __init__(self,
               n_tasks,
@@ -167,34 +105,25 @@ class TensorflowMultitaskIRVClassifier(TensorGraph):
    self.n_tasks = n_tasks
    self.K = K
    self.n_features = 2 * self.K * self.n_tasks
    logger.info("n_features after fit_transform: %d" % int(self.n_features))
    self.penalty = penalty
    super(TensorflowMultitaskIRVClassifier, self).__init__(**kwargs)
    self.build_graph()

  def build_graph(self):
    """Constructs the graph architecture of IRV as described in:

       https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2750043/
    """
    self.mol_features = Feature(shape=(None, self.n_features))
    self._labels = Label(shape=(None, self.n_tasks))
    self._weights = Weights(shape=(None, self.n_tasks))
    predictions = IRVLayer(self.n_tasks, self.K, in_layers=[self.mol_features])
    costs = []
    mol_features = Input(shape=(self.n_features,))
    predictions = IRVLayer(self.n_tasks, self.K, self.penalty)(mol_features)
    logits = []
    outputs = []
    for task in range(self.n_tasks):
      task_output = Slice(task, 1, in_layers=[predictions])
      sigmoid = Sigmoid(in_layers=[task_output])
      task_output = Slice(task, 1)(predictions)
      sigmoid = Activation(tf.sigmoid)(task_output)
      logits.append(task_output)
      outputs.append(sigmoid)

      label = Slice(task, axis=1, in_layers=[self._labels])
      cost = SigmoidCrossEntropy(in_layers=[label, task_output])
      costs.append(cost)
    all_cost = Concat(in_layers=costs, axis=1)
    loss = WeightedError(in_layers=[all_cost, self._weights]) + \
        IRVRegularize(predictions, self.penalty, in_layers=[predictions])
    self.set_loss(loss)
    outputs = Stack(axis=1, in_layers=outputs)
    outputs = Concat(axis=2, in_layers=[1 - outputs, outputs])
    self.add_output(outputs)
    outputs = layers.Stack(axis=1)(outputs)
    outputs2 = Lambda(lambda x: 1 - x)(outputs)
    outputs = [
        Concatenate(axis=2)([outputs2, outputs]),
        Concatenate(axis=1)(logits)
    ]
    model = tf.keras.Model(inputs=[mol_features], outputs=outputs)
    super(TensorflowMultitaskIRVClassifier, self).__init__(
        model,
        SigmoidCrossEntropy(),
        output_types=['prediction', 'loss'],
        **kwargs)
+80 −112
Original line number Diff line number Diff line
@@ -7,15 +7,17 @@ __license__ = "MIT"

import sys

from deepchem.models.tensorgraph.layers import Layer, Feature, Label, AtomicConvolution, L2Loss, ReduceMean
from deepchem.models import TensorGraph
from deepchem.models import KerasModel
from deepchem.models.layers import AtomicConvolution
from deepchem.models.losses import L2Loss
from tensorflow.keras.layers import Input, Layer

import numpy as np
import tensorflow as tf
import itertools


def InitializeWeightsBiases(prev_layer_size,
def initializeWeightsBiases(prev_layer_size,
                            size,
                            weights=None,
                            biases=None,
@@ -58,41 +60,28 @@ def InitializeWeightsBiases(prev_layer_size,
class AtomicConvScore(Layer):

  def __init__(self, atom_types, layer_sizes, **kwargs):
    super(AtomicConvScore, self).__init__(**kwargs)
    self.atom_types = atom_types
    self.layer_sizes = layer_sizes
    super(AtomicConvScore, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    frag1_layer = self.in_layers[0].out_tensor
    frag2_layer = self.in_layers[1].out_tensor
    complex_layer = self.in_layers[2].out_tensor

    frag1_z = self.in_layers[3].out_tensor
    frag2_z = self.in_layers[4].out_tensor
    complex_z = self.in_layers[5].out_tensor

    atom_types = self.atom_types
  def build(self, input_shape):
    self.type_weights = []
    self.type_biases = []
    self.output_weights = []
    self.output_biases = []
    n_features = int(input_shape[0][-1])
    layer_sizes = self.layer_sizes
    num_layers = len(layer_sizes)
    weight_init_stddevs = [1 / np.sqrt(x) for x in layer_sizes]
    bias_init_consts = [0.0] * num_layers

    weights = []
    biases = []
    output_weights = []
    output_biases = []

    n_features = int(frag1_layer.get_shape()[-1])

    for ind, atomtype in enumerate(atom_types):

    for ind, atomtype in enumerate(self.atom_types):
      prev_layer_size = n_features
      weights.append([])
      biases.append([])
      output_weights.append([])
      output_biases.append([])
      self.type_weights.append([])
      self.type_biases.append([])
      self.output_weights.append([])
      self.output_biases.append([])
      for i in range(num_layers):
        weight, bias = InitializeWeightsBiases(
        weight, bias = initializeWeightsBiases(
            prev_layer_size=prev_layer_size,
            size=layer_sizes[i],
            weights=tf.truncated_normal(
@@ -100,24 +89,29 @@ class AtomicConvScore(Layer):
                stddev=weight_init_stddevs[i]),
            biases=tf.constant(
                value=bias_init_consts[i], shape=[layer_sizes[i]]))
        weights[ind].append(weight)
        biases[ind].append(bias)
        self.type_weights[ind].append(weight)
        self.type_biases[ind].append(bias)
        prev_layer_size = layer_sizes[i]
      weight, bias = InitializeWeightsBiases(prev_layer_size, 1)
      output_weights[ind].append(weight)
      output_biases[ind].append(bias)
      weight, bias = initializeWeightsBiases(prev_layer_size, 1)
      self.output_weights[ind].append(weight)
      self.output_biases[ind].append(bias)

  def call(self, inputs):
    frag1_layer, frag2_layer, complex_layer, frag1_z, frag2_z, complex_z = inputs
    atom_types = self.atom_types
    num_layers = len(self.layer_sizes)

    def atomnet(current_input, atomtype):
      prev_layer = current_input
      for i in range(num_layers):
        layer = tf.nn.xw_plus_b(prev_layer, weights[atomtype][i],
                                biases[atomtype][i])
        layer = tf.nn.xw_plus_b(prev_layer, self.type_weights[atomtype][i],
                                self.type_biases[atomtype][i])
        layer = tf.nn.relu(layer)
        prev_layer = layer

      output_layer = tf.squeeze(
          tf.nn.xw_plus_b(prev_layer, output_weights[atomtype][0],
                          output_biases[atomtype][0]))
          tf.nn.xw_plus_b(prev_layer, self.output_weights[atomtype][0],
                          self.output_biases[atomtype][0]))
      return output_layer

    frag1_zeros = tf.zeros_like(frag1_z, dtype=tf.float32)
@@ -149,11 +143,10 @@ class AtomicConvScore(Layer):
    frag2_energy = tf.reduce_sum(frag2_outputs, 1)
    complex_energy = tf.reduce_sum(complex_outputs, 1)
    binding_energy = complex_energy - (frag1_energy + frag2_energy)
    self.out_tensor = tf.expand_dims(binding_energy, axis=1)
    return self.out_tensor
    return tf.expand_dims(binding_energy, axis=1)


class AtomicConvModel(TensorGraph):
class AtomicConvModel(KerasModel):

  def __init__(self,
               frag1_num_atoms=70,
@@ -202,7 +195,6 @@ class AtomicConvModel(TensorGraph):
      Learning rate for the model.
    """
    # TODO: Turning off queue for now. Safe to re-activate?
    super(AtomicConvModel, self).__init__(use_queue=False, **kwargs)
    self.complex_num_atoms = complex_num_atoms
    self.frag1_num_atoms = frag1_num_atoms
    self.frag2_num_atoms = frag2_num_atoms
@@ -211,68 +203,52 @@ class AtomicConvModel(TensorGraph):
    self.atom_types = atom_types

    rp = [x for x in itertools.product(*radial)]
    self.frag1_X = Feature(shape=(batch_size, frag1_num_atoms, 3))
    self.frag1_nbrs = Feature(
        shape=(batch_size, frag1_num_atoms, max_num_neighbors))
    self.frag1_nbrs_z = Feature(
        shape=(batch_size, frag1_num_atoms, max_num_neighbors))
    self.frag1_z = Feature(shape=(batch_size, frag1_num_atoms))

    self.frag2_X = Feature(shape=(batch_size, frag2_num_atoms, 3))
    self.frag2_nbrs = Feature(
        shape=(batch_size, frag2_num_atoms, max_num_neighbors))
    self.frag2_nbrs_z = Feature(
        shape=(batch_size, frag2_num_atoms, max_num_neighbors))
    self.frag2_z = Feature(shape=(batch_size, frag2_num_atoms))

    self.complex_X = Feature(shape=(batch_size, complex_num_atoms, 3))
    self.complex_nbrs = Feature(
        shape=(batch_size, complex_num_atoms, max_num_neighbors))
    self.complex_nbrs_z = Feature(
        shape=(batch_size, complex_num_atoms, max_num_neighbors))
    self.complex_z = Feature(shape=(batch_size, complex_num_atoms))
    frag1_X = Input(shape=(frag1_num_atoms, 3))
    frag1_nbrs = Input(shape=(frag1_num_atoms, max_num_neighbors))
    frag1_nbrs_z = Input(shape=(frag1_num_atoms, max_num_neighbors))
    frag1_z = Input(shape=(frag1_num_atoms,))

    frag2_X = Input(shape=(frag2_num_atoms, 3))
    frag2_nbrs = Input(shape=(frag2_num_atoms, max_num_neighbors))
    frag2_nbrs_z = Input(shape=(frag2_num_atoms, max_num_neighbors))
    frag2_z = Input(shape=(frag2_num_atoms,))

    complex_X = Input(shape=(complex_num_atoms, 3))
    complex_nbrs = Input(shape=(complex_num_atoms, max_num_neighbors))
    complex_nbrs_z = Input(shape=(complex_num_atoms, max_num_neighbors))
    complex_z = Input(shape=(complex_num_atoms,))

    frag1_conv = AtomicConvolution(
        atom_types=self.atom_types,
        radial_params=rp,
        boxsize=None,
        in_layers=[self.frag1_X, self.frag1_nbrs, self.frag1_nbrs_z])
        atom_types=self.atom_types, radial_params=rp,
        boxsize=None)([frag1_X, frag1_nbrs, frag1_nbrs_z])

    frag2_conv = AtomicConvolution(
        atom_types=self.atom_types,
        radial_params=rp,
        boxsize=None,
        in_layers=[self.frag2_X, self.frag2_nbrs, self.frag2_nbrs_z])
        atom_types=self.atom_types, radial_params=rp,
        boxsize=None)([frag2_X, frag2_nbrs, frag2_nbrs_z])

    complex_conv = AtomicConvolution(
        atom_types=self.atom_types,
        radial_params=rp,
        boxsize=None,
        in_layers=[self.complex_X, self.complex_nbrs, self.complex_nbrs_z])

    score = AtomicConvScore(
        self.atom_types,
        layer_sizes,
        in_layers=[
            frag1_conv, frag2_conv, complex_conv, self.frag1_z, self.frag2_z,
            self.complex_z
        ])

    self.label = Label(shape=(None, 1))
    loss = ReduceMean(in_layers=L2Loss(in_layers=[score, self.label]))
    self.add_output(score)
    self.set_loss(loss)
        atom_types=self.atom_types, radial_params=rp,
        boxsize=None)([complex_X, complex_nbrs, complex_nbrs_z])

    score = AtomicConvScore(self.atom_types, layer_sizes)(
        [frag1_conv, frag2_conv, complex_conv, frag1_z, frag2_z, complex_z])

    model = tf.keras.Model(
        inputs=[
            frag1_X, frag1_nbrs, frag1_nbrs_z, frag1_z, frag2_X, frag2_nbrs,
            frag2_nbrs_z, frag2_z, complex_X, complex_nbrs, complex_nbrs_z,
            complex_z
        ],
        outputs=score)
    super(AtomicConvModel, self).__init__(
        model, L2Loss(), batch_size=batch_size, **kwargs)

  def default_generator(self,
                        dataset,
                        epochs=1,
                        predict=False,
                        mode='fit',
                        deterministic=True,
                        pad_batches=True):
    complex_num_atoms = self.complex_num_atoms
    frag1_num_atoms = self.frag1_num_atoms
    frag2_num_atoms = self.frag2_num_atoms
    max_num_neighbors = self.max_num_neighbors
    batch_size = self.batch_size

    def replace_atom_types(z):
@@ -288,28 +264,24 @@ class AtomicConvModel(TensorGraph):
      for ind, (F_b, y_b, w_b, ids_b) in enumerate(
          dataset.iterbatches(
              batch_size, deterministic=True, pad_batches=pad_batches)):
        N = complex_num_atoms
        N_1 = frag1_num_atoms
        N_2 = frag2_num_atoms
        M = max_num_neighbors
        N = self.complex_num_atoms
        N_1 = self.frag1_num_atoms
        N_2 = self.frag2_num_atoms
        M = self.max_num_neighbors

        orig_dict = {}
        batch_size = F_b.shape[0]
        num_features = F_b[0][0].shape[1]
        frag1_X_b = np.zeros((batch_size, N_1, num_features))
        for i in range(batch_size):
          frag1_X_b[i] = F_b[i][0]
        orig_dict[self.frag1_X] = frag1_X_b

        frag2_X_b = np.zeros((batch_size, N_2, num_features))
        for i in range(batch_size):
          frag2_X_b[i] = F_b[i][3]
        orig_dict[self.frag2_X] = frag2_X_b

        complex_X_b = np.zeros((batch_size, N, num_features))
        for i in range(batch_size):
          complex_X_b[i] = F_b[i][6]
        orig_dict[self.complex_X] = complex_X_b

        frag1_Nbrs = np.zeros((batch_size, N_1, M))
        frag1_Z_b = np.zeros((batch_size, N_1))
@@ -323,9 +295,6 @@ class AtomicConvModel(TensorGraph):
            frag1_Nbrs[i, atom, :len(atom_nbrs)] = np.array(atom_nbrs)
            for j, atom_j in enumerate(atom_nbrs):
              frag1_Nbrs_Z[i, atom, j] = frag1_Z_b[i, atom_j]
        orig_dict[self.frag1_nbrs] = frag1_Nbrs
        orig_dict[self.frag1_nbrs_z] = frag1_Nbrs_Z
        orig_dict[self.frag1_z] = frag1_Z_b

        frag2_Nbrs = np.zeros((batch_size, N_2, M))
        frag2_Z_b = np.zeros((batch_size, N_2))
@@ -339,9 +308,6 @@ class AtomicConvModel(TensorGraph):
            frag2_Nbrs[i, atom, :len(atom_nbrs)] = np.array(atom_nbrs)
            for j, atom_j in enumerate(atom_nbrs):
              frag2_Nbrs_Z[i, atom, j] = frag2_Z_b[i, atom_j]
        orig_dict[self.frag2_nbrs] = frag2_Nbrs
        orig_dict[self.frag2_nbrs_z] = frag2_Nbrs_Z
        orig_dict[self.frag2_z] = frag2_Z_b

        complex_Nbrs = np.zeros((batch_size, N, M))
        complex_Z_b = np.zeros((batch_size, N))
@@ -356,8 +322,10 @@ class AtomicConvModel(TensorGraph):
            for j, atom_j in enumerate(atom_nbrs):
              complex_Nbrs_Z[i, atom, j] = complex_Z_b[i, atom_j]

        orig_dict[self.complex_nbrs] = complex_Nbrs
        orig_dict[self.complex_nbrs_z] = complex_Nbrs_Z
        orig_dict[self.complex_z] = complex_Z_b
        orig_dict[self.label] = np.reshape(y_b, newshape=(batch_size, 1))
        yield orig_dict
        inputs = [
            frag1_X_b, frag1_Nbrs, frag1_Nbrs_Z, frag1_Z_b, frag2_X_b,
            frag2_Nbrs, frag2_Nbrs_Z, frag2_Z_b, complex_X_b, complex_Nbrs,
            complex_Nbrs_Z, complex_Z_b
        ]
        y_b = np.reshape(y_b, newshape=(batch_size, 1))
        yield (inputs, [y_b], [w_b])
+27 −56

File changed.

Preview size limit exceeded, changes collapsed.

Loading