Commit 4bafe42d authored by VIGNESHinZONE
Browse files

adding tests and docs

parent ffd1189c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ from deepchem.feat.molecule_featurizers import MACCSKeysFingerprint
from deepchem.feat.molecule_featurizers import MordredDescriptors
from deepchem.feat.molecule_featurizers import Mol2VecFingerprint
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers import PagtnMolGraphFeaturizer
from deepchem.feat.molecule_featurizers import MolGanFeaturizer
from deepchem.feat.molecule_featurizers import OneHotFeaturizer
from deepchem.feat.molecule_featurizers import PubChemFingerprint
+1 −0
Original line number Diff line number Diff line
@@ -15,4 +15,5 @@ from deepchem.feat.molecule_featurizers.rdkit_descriptors import RDKitDescriptor
from deepchem.feat.molecule_featurizers.smiles_to_image import SmilesToImage
from deepchem.feat.molecule_featurizers.smiles_to_seq import SmilesToSeq, create_char_to_idx
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import PagtnMolGraphFeaturizer
from deepchem.feat.molecule_featurizers.molgan_featurizer import MolGanFeaturizer
+7 −8
Original line number Diff line number Diff line
@@ -225,7 +225,7 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
        edge_features=bond_features)


class PAGTNFeaturizer(MolecularFeaturizer):
class PagtnMolGraphFeaturizer(MolecularFeaturizer):

  def __init__(self, max_length=5):

@@ -242,7 +242,7 @@ class PAGTNFeaturizer(MolecularFeaturizer):
    self.ordered_pair = lambda a, b: (a, b) if a < b else (b, a)
    self.max_length = max_length

  def PAGTNAtomFeaturizer(self, atom):
  def PagtnAtomFeaturizer(self, atom):
    atom_type = get_atom_type_one_hot(atom, self.SYMBOLS, False)
    formal_charge = get_atom_formal_charge_one_hot(
        atom, include_unknown_set=False)
@@ -255,7 +255,7 @@ class PAGTNFeaturizer(MolecularFeaturizer):
    ])
    return atom_feat

  def bond_features(self, mol, path_atoms, ring_info):
  def _bond_features(self, mol, path_atoms, ring_info):
    """Computes the edge features for a given pair of nodes.
        Parameters
        ----------
@@ -285,7 +285,6 @@ class PAGTNFeaturizer(MolecularFeaturizer):
        features.append(np.concatenate([bond_type, conjugacy, ring_attach]))
      else:
        features.append(np.zeros(6))
        #features.append([0, 0, 0, 0, 0, 0])

    if path_length + 1 > self.max_length:
      path_length = self.max_length + 1
@@ -304,7 +303,7 @@ class PAGTNFeaturizer(MolecularFeaturizer):
                      one_hot_encode(ring_info, allowable_set=self.RING_TYPES))
    return np.concatenate(features, axis=0)

  def PAGTNBondFeaturizer(self, mol):
  def PagtnBondFeaturizer(self, mol):
    """Featurizes the input molecule.
        Parameters
        ----------
@@ -335,14 +334,14 @@ class PAGTNFeaturizer(MolecularFeaturizer):
          feats.append(np.zeros(7 * self.max_length + 7))
          continue
        ring_info = rings_dict.get(self.ordered_pair(i, j), [])
        feats.append(self.bond_features(mol, paths_dict[(i, j)], ring_info))
        feats.append(self._bond_features(mol, paths_dict[(i, j)], ring_info))

    return np.array([src, dest], dtype=np.int), np.array(feats, dtype=np.float)

  def _featurize(self, mol):
    """Build the graph representation of a single molecule.

    Parameters
    ----------
    mol:
      Molecule object exposing ``GetAtoms()`` (presumably an RDKit Mol;
      confirm against the caller).

    Returns
    -------
    GraphData
      Graph constructed from the per-atom node features, the edge index
      and the edge features returned by ``PagtnBondFeaturizer``.
    """
    # Builtin ``float`` is the dtype the ``np.float`` alias resolved to; the
    # alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    node_features = np.asarray(
        [self.PagtnAtomFeaturizer(atom) for atom in mol.GetAtoms()],
        dtype=float)
    edge_index, edge_features = self.PagtnBondFeaturizer(mol)
    graph = GraphData(node_features, edge_index, edge_features)
    return graph
+22 −0
Original line number Diff line number Diff line
import unittest

from deepchem.feat import MolGraphConvFeaturizer
from deepchem.feat import PagtnMolGraphFeaturizer


class TestMolGraphConvFeaturizer(unittest.TestCase):
@@ -70,3 +71,24 @@ class TestMolGraphConvFeaturizer(unittest.TestCase):
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 31
    assert graph_feat[1].num_edges == 44


class TestPagtnMolGraphConvFeaturizer(unittest.TestCase):
  """Checks the graph sizes produced by ``PagtnMolGraphFeaturizer``."""

  def test_default_featurizer(self):
    """Featurize two SMILES strings with ``max_length=5`` and check sizes."""
    smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    featurizer = PagtnMolGraphFeaturizer(max_length=5)
    graph_feat = featurizer.featurize(smiles)
    assert len(graph_feat) == 2

    # assert "C1=CC=CN=C1"
    # Expected edge count equals num_nodes squared (6 * 6 == 36).
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 94
    assert graph_feat[0].num_edges == 36
    assert graph_feat[0].num_edge_features == 42

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    # Expected edge count equals num_nodes squared (22 * 22 == 484).
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 94
    assert graph_feat[1].num_edges == 484
    # Check the second graph here; the first graph's edge feature width is
    # already asserted above.
    assert graph_feat[1].num_edge_features == 42