Commit 8535d72d authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Concatenated the 3 returned arrays to make one large array

parent 458164ea
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from sklearn.metrics import pairwise_distances
class MATFeaturizer(MolecularFeaturizer):
  """
  This class is a featurizer for the Molecule Attention Transformer [1]_.
  The featurizer accepts an RDKit Molecule, and 2 booleans (add_dummy_node and one_hot_formal_charge) as arguments.
  The featurizer accepts an RDKit Molecule, and a boolean (one_hot_formal_charge) as arguments.
  The returned value is a numpy array which consists of molecular graph descriptions:
    - Node Features
    - Adjacency Matrix
@@ -28,7 +28,6 @@ class MATFeaturizer(MolecularFeaturizer):

  def __init__(
      self,
      add_dummy_node: bool = True,
      one_hot_formal_charge: bool = True,
  ):
    """
@@ -40,7 +39,6 @@ class MATFeaturizer(MolecularFeaturizer):
      If True, formal charges on atoms are one-hot encoded.
    """

    self.add_dummy_node = add_dummy_node
    self.one_hot_formal_charge = one_hot_formal_charge

  def atom_features(self, atom):
@@ -86,7 +84,7 @@ class MATFeaturizer(MolecularFeaturizer):
    
    Returns
    -------
    numpy.ndarray: (node_features, adjacency_matrix, distance_matrix)
    Tuple[np.ndarray]: (node_features, adjacency_matrix, distance_matrix)
    """

    node_features = np.array(
@@ -96,8 +94,17 @@ class MATFeaturizer(MolecularFeaturizer):

    distance_matrix = Chem.rdmolops.GetDistanceMatrix(mol)

    adjacency_matrix.resize(node_features.shape)
    distance_matrix.resize(node_features.shape)
    result = node_features

    return node_features, adjacency_matrix, distance_matrix
    result = np.zeros((node_features.shape[0],
                       node_features.shape[1] + 2 * adjacency_matrix.shape[1]))

    for i in range(node_features.shape[0]):
      result[i, :node_features.shape[1]] = node_features[i]
      result[i, node_features.shape[1]:node_features.shape[1] +
             adjacency_matrix.shape[1]] = adjacency_matrix[i]
      result[i, node_features.shape[1] + adjacency_matrix.shape[1]:
             node_features.shape[1] + adjacency_matrix.shape[1] +
             distance_matrix.shape[1]] = distance_matrix[i]

    return result