Unverified Commit 96de1d14 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2108 from deepchem/N2

Implementing N2 Weave Models
parents d33cc476 139792e3
Loading
Loading
Loading
Loading
+187 −52
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ from deepchem.feat.atomic_coordinates import ComplexNeighborListFragmentAtomicCo
from deepchem.feat.mol_graphs import ConvMol, WeaveMol
from deepchem.data import DiskDataset
import logging
from typing import Optional, List
from deepchem.utils.typing import RDKitMol, RDKitAtom


def one_of_k_encoding(x, allowable_set):
@@ -398,12 +400,75 @@ def bond_features(bond, use_chirality=False):
  ]
  if use_chirality:
    bond_feats = bond_feats + one_of_k_encoding_unk(
        str(bond.GetStereo()), possible_bond_stereo)
        str(bond.GetStereo()), GraphConvCoonstants.possible_bond_stereo)
  return np.array(bond_feats)


def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
                  graph_distance=True):
def max_pair_distance_pairs(mol: RDKitMol,
                            max_pair_distance: Optional[int]) -> np.ndarray:
  """Helper method which finds atom pairs within max_pair_distance graph distance.

  Two atoms `i` and `j` are within graph distance `k` of one another exactly
  when a walk of at most `k` bonds connects them. This is computed by
  starting from the self-pairs `(i, i)` and repeatedly expanding the set of
  reachable pairs by one bond, at most `max_pair_distance` times. The
  expansion is done on a boolean reachability matrix rather than by summing
  integer powers of the adjacency matrix: the number of walks grows
  exponentially with the power and can overflow the integer dtype for larger
  molecules, whereas reachability only needs to know whether a walk exists.

  Parameters
  ----------
  mol: rdkit.Chem.rdchem.Mol
    RDKit molecules
  max_pair_distance: Optional[int], (default None)
    This value can be a positive integer or None. This
    parameter determines the maximum graph distance at which pair
    features are computed. For example, if `max_pair_distance==2`,
    then pair features are computed only for atoms at most graph
    distance 2 apart. If `max_pair_distance` is `None`, all pairs of
    atoms connected by some path are considered (effectively infinite
    `max_pair_distance`).

  Returns
  -------
  np.ndarray
    Of shape `(2, num_pairs)` where `num_pairs` is the total number of pairs
    within `max_pair_distance` of one another. Row 0 holds the first atom
    index of each pair and row 1 the second; pairs are listed in row-major
    order, include the self-pairs `(i, i)`, and contain both `(i, j)` and
    `(j, i)`.

  Raises
  ------
  ValueError
    If `max_pair_distance` is an integer less than or equal to zero.
  """
  from rdkit.Chem import rdmolops
  # Validate before using the argument so bad values are always rejected.
  if max_pair_distance is not None and max_pair_distance <= 0:
    raise ValueError(
        "max_pair_distance must either be a positive integer or None")
  N = len(mol.GetAtoms())
  # No shortest path in an N-atom graph is longer than N bonds, so N steps
  # is an effectively infinite cutoff.
  if max_pair_distance is None or max_pair_distance >= N:
    max_distance = N
  else:
    max_distance = max_pair_distance
  adj = np.asarray(rdmolops.GetAdjacencyMatrix(mol), dtype=np.int64)
  # Handle edge case of self-pairs (i, i): every atom reaches itself.
  reachable = np.eye(N, dtype=bool)
  for _ in range(max_distance):
    # Entries of this product are bounded by N, so they cannot overflow.
    one_more_bond = (reachable.astype(np.int64) @ adj) > 0
    expanded = reachable | one_more_bond
    if np.array_equal(expanded, reachable):
      # Nothing new was reached; further steps cannot add pairs.
      break
    reachable = expanded
  # np.nonzero returns (row_indices, col_indices); stacking them gives the
  # (2, num_pairs) layout.
  return np.array(np.nonzero(reachable))


def pair_features(mol: RDKitMol,
                  bond_features_map: dict,
                  bond_adj_list: List,
                  bt_len: int = 6,
                  graph_distance: bool = True,
                  max_pair_distance: Optional[int] = None) -> np.ndarray:
  """Helper method used to compute atom pair feature vectors.

  Many different featurization methods compute atom pair features
@@ -415,16 +480,26 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
  ----------
  mol: RDKit Mol
    Molecule to compute features on.
  edge_list: list
    List of edges to consider
  canon_adj_list: list of lists
    `canon_adj_list[i]` is a list of the atom indices that atom `i` shares a
    list. This list is symmetrical so if `j in canon_adj_list[i]` then `i in
    canon_adj_list[j]`.
  bond_features_map: dict 
    Dictionary that maps pairs of atom ids (say `(2, 3)` for a bond between
    atoms 2 and 3) to the features for the bond between them.
  bond_adj_list: list of lists
    `bond_adj_list[i]` is a list of the atom indices that atom `i` shares a
    bond with . This list is symmetrical so if `j in bond_adj_list[i]` then `i
    in bond_adj_list[j]`.
  bt_len: int, optional (default 6)
    The number of different bond types to consider.
  graph_distance: bool, optional (default True)
    If true, use graph distance between molecules. Else use euclidean distance.
    If true, use graph distance between atoms. Else use Euclidean
    distance, in which case the specified `mol` must have a conformer;
    atomic positions will be retrieved by calling `mol.GetConformer(0)`.
  max_pair_distance: Optional[int], (default None)
    This value can be a positive integer or None. This
    parameter determines the maximum graph distance at which pair
    features are computed. For example, if `max_pair_distance==2`,
    then pair features are computed only for atoms at most graph
    distance 2 apart. If `max_pair_distance` is `None`, all pairs are
    considered (effectively infinite `max_pair_distance`)

  Note
  ----
@@ -433,32 +508,65 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
  Returns
  -------
  features: np.ndarray
    Of shape `(N, N, bt_len + max_distance + 1)`. This is the array of pairwise
    features for all atom pairs.
    Of shape `(N_edges, bt_len + max_distance + 1)`. This is the array
    of pairwise features for all atom pairs, where N_edges is the
    number of edges within max_pair_distance of one another in this
    molecules.
  pair_edges: np.ndarray
    Of shape `(2, num_pairs)` where `num_pairs` is the total number of
    pairs within `max_pair_distance` of one another.
  """
  if graph_distance:
    max_distance = 7
  else:
    max_distance = 1
  N = mol.GetNumAtoms()
  features = np.zeros((N, N, bt_len + max_distance + 1))
  pair_edges = max_pair_distance_pairs(mol, max_pair_distance)
  num_pairs = pair_edges.shape[1]
  N_edges = pair_edges.shape[1]
  features = np.zeros((N_edges, bt_len + max_distance + 1))
  # Get mapping
  mapping = {}
  for n in range(N_edges):
    a1, a2 = pair_edges[:, n]
    mapping[(int(a1), int(a2))] = n
  num_atoms = mol.GetNumAtoms()
  rings = mol.GetRingInfo().AtomRings()
  for a1 in range(num_atoms):
    for a2 in canon_adj_list[a1]:
    for a2 in bond_adj_list[a1]:
      # first `bt_len` features are bond features(if applicable)
      features[a1, a2, :bt_len] = np.asarray(
          edge_list[tuple(sorted((a1, a2)))], dtype=float)
      if (int(a1), int(a2)) not in mapping:
        raise ValueError(
            "Malformed molecule with bonds not in specified graph distance.")
      else:
        n = mapping[(int(a1), int(a2))]
      features[n, :bt_len] = np.asarray(
          bond_features_map[tuple(sorted((a1, a2)))], dtype=float)
    for ring in rings:
      if a1 in ring:
        for a2 in ring:
          if (int(a1), int(a2)) not in mapping:
            # For ring pairs outside max pairs distance continue
            continue
          else:
            n = mapping[(int(a1), int(a2))]
          # `bt_len`-th feature is if the pair of atoms are in the same ring
        features[a1, ring, bt_len] = 1
        features[a1, a1, bt_len] = 0.
          if a2 == a1:
            features[n, bt_len] = 0
          else:
            features[n, bt_len] = 1
    # graph distance between two atoms
    if graph_distance:
      # distance is a matrix of 1-hot encoded distances for all atoms
      distance = find_distance(
          a1, num_atoms, canon_adj_list, max_distance=max_distance)
      features[a1, :, bt_len + 1:] = distance
          a1, num_atoms, bond_adj_list, max_distance=max_distance)
      for a2 in range(num_atoms):
        if (int(a1), int(a2)) not in mapping:
          # Skip pairs that lie outside max_pair_distance
          continue
        else:
          n = mapping[(int(a1), int(a2))]
          features[n, bt_len + 1:] = distance[a2]
  # Euclidean distance between atoms
  if not graph_distance:
    coords = np.zeros((N, 3))
@@ -469,10 +577,11 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
      np.stack([coords] * N, axis=1) - \
      np.stack([coords] * N, axis=0)), axis=2))

  return features
  return features, pair_edges


def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
def find_distance(a1: RDKitAtom, num_atoms: int, bond_adj_list,
                  max_distance=7) -> np.ndarray:
  """Computes distances from provided atom.

  Parameters
@@ -481,10 +590,10 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
    The source atom to compute distances from.
  num_atoms: int
    The total number of atoms.
  canon_adj_list: list of lists
    `canon_adj_list[i]` is a list of the atom indices that atom `i` shares a
    list. This list is symmetrical so if `j in canon_adj_list[i]` then `i in
    canon_adj_list[j]`.
  bond_adj_list: list of lists
    `bond_adj_list[i]` is a list of the atom indices that atom `i` shares a
    bond with. This list is symmetrical so if `j in bond_adj_list[i]` then `i in
    bond_adj_list[j]`.
  max_distance: int, optional (default 7)
    The max distance to search.

@@ -498,7 +607,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
  distance = np.zeros((num_atoms, max_distance))
  radial = 0
  # atoms `radial` bonds away from `a1`
  adj_list = set(canon_adj_list[a1])
  adj_list = set(bond_adj_list[a1])
  # atoms less than `radial` bonds away
  all_list = set([a1])
  while radial < max_distance:
@@ -507,7 +616,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
    # find atoms `radial`+1 bonds away
    next_adj = set()
    for adj in adj_list:
      next_adj.update(canon_adj_list[adj])
      next_adj.update(bond_adj_list[adj])
    adj_list = next_adj - all_list
    radial = radial + 1
  return distance
@@ -647,6 +756,14 @@ class WeaveFeaturizer(MolecularFeaturizer):
  descriptors for each pair of atoms. These extra descriptors may provide for
  additional descriptive power but at the cost of a larger featurized dataset.


  Examples
  --------
  >>> import deepchem as dc
  >>> mols = ["C", "CCC"]
  >>> featurizer = dc.feat.WeaveFeaturizer()
  >>> X = featurizer.featurize(mols)

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
@@ -660,18 +777,31 @@ class WeaveFeaturizer(MolecularFeaturizer):

  name = ['weave_mol']

  def __init__(self, graph_distance=True, explicit_H=False,
               use_chirality=False):
    """
  def __init__(self,
               graph_distance: bool = True,
               explicit_H: bool = False,
               use_chirality: bool = False,
               max_pair_distance: Optional[int] = None):
    """Initialize this featurizer with set parameters.

    Parameters
    ----------
    graph_distance: bool, optional
      If true, use graph distance. Otherwise, use Euclidean
      distance.
    explicit_H: bool, optional
    graph_distance: bool, (default True)
      If True, use graph distance for distance features. Otherwise, use
      Euclidean distance. Note that molecules this featurizer is invoked
      on must have valid conformer information when this option is set
      to False.
    explicit_H: bool, (default False) 
      If true, model hydrogens in the molecule.
    use_chirality: bool, optional
    use_chirality: bool, (default False)
      If true, use chiral information in the featurization
    max_pair_distance: Optional[int], (default None)
      This value can be a positive integer or None. This
      parameter determines the maximum graph distance at which pair
      features are computed. For example, if `max_pair_distance==2`,
      then pair features are computed only for atoms at most graph
      distance 2 apart. If `max_pair_distance` is `None`, all pairs are
      considered (effectively infinite `max_pair_distance`)
    """
    # Distance is either graph distance(True) or Euclidean distance(False,
    # only support datasets providing Cartesian coordinates)
@@ -682,9 +812,13 @@ class WeaveFeaturizer(MolecularFeaturizer):
    self.explicit_H = explicit_H
    # If uses use_chirality
    self.use_chirality = use_chirality
    if isinstance(max_pair_distance, int) and max_pair_distance <= 0:
      raise ValueError(
          "max_pair_distance must either be a positive integer or None")
    self.max_pair_distance = max_pair_distance
    if self.use_chirality:
      self.bt_len = int(
          GraphConvConstants.bond_fdim_base) + len(possible_bond_stereo)
      self.bt_len = int(GraphConvConstants.bond_fdim_base) + len(
          GraphConvConstants.possible_bond_stereo)
    else:
      self.bt_len = int(GraphConvConstants.bond_fdim_base)

@@ -704,27 +838,28 @@ class WeaveFeaturizer(MolecularFeaturizer):
    nodes = np.vstack(nodes)

    # Get bond lists
    edge_list = {}
    bond_features_map = {}
    for b in mol.GetBonds():
      edge_list[tuple(sorted([b.GetBeginAtomIdx(),
      bond_features_map[tuple(sorted([b.GetBeginAtomIdx(),
                                      b.GetEndAtomIdx()]))] = bond_features(
                                          b, use_chirality=self.use_chirality)

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
    for edge in edge_list.keys():
      canon_adj_list[edge[0]].append(edge[1])
      canon_adj_list[edge[1]].append(edge[0])
    bond_adj_list = [[] for mol_id in range(len(nodes))]
    for bond in bond_features_map.keys():
      bond_adj_list[bond[0]].append(bond[1])
      bond_adj_list[bond[1]].append(bond[0])

    # Calculate pair features
    pairs = pair_features(
    pairs, pair_edges = pair_features(
        mol,
        edge_list,
        canon_adj_list,
        bond_features_map,
        bond_adj_list,
        bt_len=self.bt_len,
        graph_distance=self.graph_distance)
        graph_distance=self.graph_distance,
        max_pair_distance=self.max_pair_distance)

    return WeaveMol(nodes, pairs)
    return WeaveMol(nodes, pairs, pair_edges)


class AtomicConvFeaturizer(ComplexNeighborListFragmentAtomicCoordinates):
+10 −7
Original line number Diff line number Diff line
"""
Data Structures used to represented molecules for convolutions.
"""
__author__ = "Han Altae-Tran and Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import csv
import random
import numpy as np
@@ -375,16 +371,23 @@ class WeaveMol(object):
  """Molecular featurization object for weave convolutions.

  These objects are produced by WeaveFeaturizer, and feed into
  WeaveModel. The underlying implementation is inspired by:
  WeaveModel. The underlying implementation is inspired by [1]_.


  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  """

  def __init__(self, nodes, pairs):
  def __init__(self, nodes, pairs, pair_edges):
    """Store the atom features, pair features and pair indices of a molecule.

    Parameters
    ----------
    nodes: np.ndarray
      Atom feature matrix. Its first axis is read as the number of atoms
      and its second axis as the number of features per atom.
    pairs: np.ndarray
      Pair feature array, stored as given and returned by
      `get_pair_features`.
    pair_edges: np.ndarray
      Atom index pairs that `pairs` describes, stored as given and returned
      by `get_pair_edges`. NOTE(review): presumably of shape
      `(2, num_pairs)` as produced by `WeaveFeaturizer` -- confirm.
    """
    self.nodes = nodes
    self.pairs = pairs
    # Atom count and per-atom feature count are derived from `nodes`.
    self.num_atoms = self.nodes.shape[0]
    self.n_features = self.nodes.shape[1]
    self.pair_edges = pair_edges

  def get_pair_edges(self):
    """Return the `pair_edges` array this object was constructed with."""
    return self.pair_edges

  def get_pair_features(self):
    """Return the `pairs` feature array this object was constructed with."""
    return self.pairs
+0 −4
Original line number Diff line number Diff line
"""
Tests for ConvMolFeaturizer. 
"""
__author__ = "Han Altae-Tran and Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import unittest
import os
import numpy as np
+129 −0
Original line number Diff line number Diff line
"""
Tests for weave featurizer.
"""
import numpy as np
import deepchem as dc
from deepchem.feat.graph_features import max_pair_distance_pairs


def test_max_pair_distance_pairs():
  """Test that max pair distance pairs are computed properly."""
  from rdkit import Chem
  # A lone carbon only ever has its self-pair (0, 0), whatever the cutoff.
  mol = Chem.MolFromSmiles('C')
  for max_pair_distance in (1, 2):
    pair_edges = max_pair_distance_pairs(mol, max_pair_distance)
    assert pair_edges.shape == (2, 1)
    assert np.all(pair_edges.flatten() == np.array([0, 0]))

  # Test alkane (propane, atoms bonded 0-1-2)
  mol = Chem.MolFromSmiles('CCC')
  # Test distance 1
  pair_edges = max_pair_distance_pairs(mol, 1)
  # 3 self connections and 2 bonds which are both counted twice because of
  # symmetry for 7 total
  assert pair_edges.shape == (2, 7)
  # Check the pairs themselves, not just how many there are. The terminal
  # carbons (0 and 2) are two bonds apart and must be absent.
  assert set(zip(*pair_edges.tolist())) == {
      (0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)
  }
  # Test distance 2
  pair_edges = max_pair_distance_pairs(mol, 2)
  # Everything is connected at this distance
  assert pair_edges.shape == (2, 9)
  assert set(zip(*pair_edges.tolist())) == {
      (i, j) for i in range(3) for j in range(3)
  }

  # Non-positive cutoffs must be rejected.
  for bad_distance in (0, -1):
    try:
      max_pair_distance_pairs(mol, bad_distance)
    except ValueError:
      pass
    else:
      raise AssertionError(
          "Expected ValueError for max_pair_distance=%d" % bad_distance)


def test_max_pair_distance_infinity():
  """Check pair computation when no distance cutoff (`None`) is given."""
  from rdkit import Chem
  # With no cutoff every atom of a connected chain pairs with every atom
  # (itself included), so propane gives 3 * 3 pairs and pentane 5 * 5.
  propane = Chem.MolFromSmiles('CCC')
  propane_edges = max_pair_distance_pairs(propane, None)
  assert propane_edges.shape == (2, 9)

  pentane = Chem.MolFromSmiles('CCCCC')
  pentane_edges = max_pair_distance_pairs(pentane, None)
  assert pentane_edges.shape == (2, 25)


def test_weave_single_carbon():
  """Featurize a lone carbon atom and check the resulting sizes."""
  weave_featurizer = dc.feat.WeaveFeaturizer()
  featurized = weave_featurizer.featurize(['C'])
  weave_mol = featurized[0]

  # The molecule consists of exactly one atom.
  assert weave_mol.get_num_atoms() == 1

  # Per-atom feature vector length.
  assert weave_mol.get_num_features() == 75

  # Without bonds the only pair is the atom's self interaction.
  pair_features = weave_mol.get_pair_features()
  assert pair_features.shape == (1 * 1, 14)


def test_weave_alkane():
  """Featurize propane with default settings and check the resulting sizes."""
  weave_featurizer = dc.feat.WeaveFeaturizer()
  featurized = weave_featurizer.featurize(['CCC'])
  weave_mol = featurized[0]

  # Propane has 3 carbons.
  assert weave_mol.get_num_atoms() == 3

  # Per-atom feature vector length.
  assert weave_mol.get_num_features() == 75

  # All 3 x 3 atom pairs are present, flattened along the first axis.
  pair_features = weave_mol.get_pair_features()
  assert pair_features.shape == (3 * 3, 14)


def test_weave_alkane_max_pairs():
  """Featurize propane with a pair-distance cutoff of one bond."""
  from rdkit import Chem
  smiles = 'CCC'
  weave_featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=1)
  # Call the per-molecule hook directly on an RDKit molecule.
  rdkit_mol = Chem.MolFromSmiles(smiles)
  weave_mol = weave_featurizer._featurize(rdkit_mol)

  # Propane has 3 carbons.
  assert weave_mol.get_num_atoms() == 3

  # Per-atom feature vector length.
  assert weave_mol.get_num_features() == 75

  # 7 pairs lie within graph distance 1: 3 self interactions plus 2 bonds,
  # each bond counted in both directions.
  pair_features = weave_mol.get_pair_features()
  assert pair_features.shape == (7, 14)


def test_carbon_nitrogen():
  """Test on carbon nitrogen molecule"""
  # Note there is a central nitrogen of degree 4, with 4 carbons
  # of degree 1 (connected only to central nitrogen).
  smiles = ['C[N+](C)(C)C']
  featurizer = dc.feat.WeaveFeaturizer()
  # Keep the featurized objects under their own name rather than rebinding
  # the list of SMILES strings.
  featurized = featurizer.featurize(smiles)
  mol = featurized[0]

  # 5 atoms in compound
  assert mol.get_num_atoms() == 5

  # Test feature sizes
  assert mol.get_num_features() == 75

  # Should be a 5x5 interaction grid (flattened to 25 pairs)
  assert mol.get_pair_features().shape == (5 * 5, 14)
+66 −35
Original line number Diff line number Diff line
@@ -197,7 +197,6 @@ class WeaveModel(KerasModel):
    self.n_classes = n_classes

    # Build the model.

    atom_features = Input(shape=(self.n_atom_feat[0],))
    pair_features = Input(shape=(self.n_pair_feat[0],))
    pair_split = Input(shape=tuple(), dtype=tf.int32)
@@ -277,6 +276,71 @@ class WeaveModel(KerasModel):
    super(WeaveModel, self).__init__(
        model, loss, output_types=output_types, batch_size=batch_size, **kwargs)

  def compute_features_on_batch(self, X_b):
    """Compute tensors that will be input into the model from featurized representation.

    The featurized input to `WeaveModel` is instances of `WeaveMol` created by
    `WeaveFeaturizer`. This method converts input `WeaveMol` objects into
    tensors used by the Keras implementation to compute `WeaveModel` outputs.
    Atom indices are local to each molecule in the input; in the output they
    are shifted by the number of atoms in all preceding molecules so that
    they index into the concatenated `atom_feat` array.

    Parameters
    ----------
    X_b: np.ndarray
      A numpy array with dtype=object where elements are `WeaveMol` objects.

    Returns
    -------
    atom_feat: np.ndarray
      Of shape `(N_atoms, N_atom_feat)`.
    pair_feat: np.ndarray
      Of shape `(N_pairs, N_pair_feat)`. `N_pairs` is the total number of
      pairs stored in the `WeaveMol` objects of the batch, which depends on
      the `max_pair_distance` used when they were featurized.
    pair_split: np.ndarray
      Of shape `(N_pairs,)`. The i-th entry in this array will tell you the
      originating atom for this pair (the "source"). Note that pairs are
      symmetric so for a pair `(a, b)`, both `a` and `b` will separately be
      sources at different points in this array.
    atom_split: np.ndarray
      Of shape `(N_atoms,)`. The i-th entry in this array will be the index
      of the molecule which the i-th atom belongs to.
    atom_to_pair: np.ndarray
      Of shape `(N_pairs, 2)`. The i-th row in this array will be the array
      `[a, b]` if `(a, b)` is a pair to be considered. (Note by symmetry, this
      implies some other row will contain `[b, a]`.)
    """
    atom_feat = []
    pair_feat = []
    atom_split = []
    atom_to_pair = []
    pair_split = []
    # Number of atoms seen so far; offsets per-molecule atom indices into
    # batch-wide indices.
    start = 0
    for im, mol in enumerate(X_b):
      n_atoms = mol.get_num_atoms()
      # number of atoms in each molecule
      atom_split.extend([im] * n_atoms)
      # pair_edges is transposed to one `[a, b]` row per pair and shifted
      # to batch-wide atom indices
      batch_pairs = mol.get_pair_edges().T + start
      atom_to_pair.append(batch_pairs)
      # source atom of every pair
      pair_split.extend(batch_pairs[:, 0])
      start = start + n_atoms

      # atom features
      atom_feat.append(mol.get_atom_features())
      # pair features
      pair_feat.append(mol.get_pair_features())

    return (np.concatenate(atom_feat, axis=0), np.concatenate(
        pair_feat, axis=0), np.array(pair_split), np.array(atom_split),
            np.concatenate(atom_to_pair, axis=0))

  def default_generator(
      self,
      dataset: Dataset,
@@ -313,40 +377,7 @@ class WeaveModel(KerasModel):
          if self.mode == 'classification':
            y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape(
                -1, self.n_tasks, self.n_classes)
        atom_feat = []
        pair_feat = []
        atom_split = []
        atom_to_pair = []
        pair_split = []
        start = 0
        for im, mol in enumerate(X_b):
          n_atoms = mol.get_num_atoms()
          # number of atoms in each molecule
          atom_split.extend([im] * n_atoms)
          # index of pair features
          C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
          atom_to_pair.append(
              np.transpose(
                  np.array([C1.flatten() + start,
                            C0.flatten() + start])))
          # number of pairs for each atom
          pair_split.extend(C1.flatten() + start)
          start = start + n_atoms

          # atom features
          atom_feat.append(mol.get_atom_features())
          # pair features
          pair_feat.append(
              np.reshape(mol.get_pair_features(),
                         (n_atoms * n_atoms, self.n_pair_feat[0])))

        inputs = [
            np.concatenate(atom_feat, axis=0),
            np.concatenate(pair_feat, axis=0),
            np.array(pair_split),
            np.array(atom_split),
            np.concatenate(atom_to_pair, axis=0)
        ]
        inputs = self.compute_features_on_batch(X_b)
        yield (inputs, [y_b], [w_b])


Loading