Commit 1a7996a5 authored by nd-02110114's avatar nd-02110114
Browse files

add docstrings and test

parent b26de50a
Loading
Loading
Loading
Loading
+105 −124
Original line number Diff line number Diff line
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Sequence, Tuple
import numpy as np

from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.utils.graph_conv_utils import get_atom_type_one_hot, get_atomic_number, \
from deepchem.feat.graph_data import GraphData
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.graph_conv_utils import get_atom_type_one_hot, \
  construct_hydrogen_bonding_info, get_atom_hydrogen_bonding_one_hot, \
  get_atom_is_in_aromatic_one_hot, get_atom_hybridization_one_hot, \
  get_atom_total_num_Hs, get_atom_chirality_one_hot, get_atom_formal_charge, \
  get_atom_partial_charge, get_atom_ring_size_one_hot, get_bond_type_one_hot, \
  get_bond_is_in_same_ring_one_hot, get_bond_graph_distance_one_hot, \
  get_bond_euclidean_distance
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.feat.graph_data import GraphData
  get_atom_total_num_Hs_one_hot, get_atom_chirality_one_hot, get_atom_formal_charge, \
  get_atom_partial_charge, get_atom_ring_size_one_hot, get_atom_total_degree_one_hot, \
  get_bond_type_one_hot, get_bond_is_in_same_ring_one_hot, get_bond_is_conjugated_one_hot, \
  get_bond_stereo_one_hot


def constrcut_atom_feature(atom: RDKitAtom, h_bond_infos: List[Tuple[int, str]],
                           sssr: List[Sequence]) -> List[float]:
  """Construct an atom feature from a RDKit atom object.

  The feature is the concatenation, in this order, of the atom type,
  chirality, formal charge, partial charge, ring size, hybridization,
  hydrogen bonding, aromaticity, degree and number-of-hydrogens encodings
  returned by the helpers imported from `deepchem.utils.graph_conv_utils`.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  h_bond_infos: List[Tuple[int, str]]
    A list of tuple `(atom_index, hydrogen_bonding_type)`.
    Basically, it is expected that this value is the return value of
    `construct_hydrogen_bonding_info`. The `hydrogen_bonding_type`
    value is "Acceptor" or "Donor".
  sssr: List[Sequence]
    The return value of `Chem.GetSymmSSSR(mol)`.
    The value is a sequence of rings.

  Returns
  -------
  List[float]
    The atom feature vector. Most segments are one-hot encodings; the
    formal charge and the partial charge each contribute a single value.

  Notes
  -----
  Partial charges must already be stored on the atom (for example by
  calling `AllChem.ComputeGasteigerCharges(mol)`) before this function is
  used, because `get_atom_partial_charge` only reads the stored value.
  """
  atom_type = get_atom_type_one_hot(atom)
  chirality = get_atom_chirality_one_hot(atom)
  formal_charge = get_atom_formal_charge(atom)
  partial_charge = get_atom_partial_charge(atom)
  ring_size = get_atom_ring_size_one_hot(atom, sssr)
  hybridization = get_atom_hybridization_one_hot(atom)
  acceptor_donor = get_atom_hydrogen_bonding_one_hot(atom, h_bond_infos)
  aromatic = get_atom_is_in_aromatic_one_hot(atom)
  degree = get_atom_total_degree_one_hot(atom)
  total_num = get_atom_total_num_Hs_one_hot(atom)
  # The concatenation order defines the feature layout; keep it stable.
  return atom_type + chirality + formal_charge + partial_charge + \
    ring_size + hybridization + acceptor_donor + aromatic + degree + total_num


def construct_bond_feature(bond: RDKitBond) -> List[float]:
  """Construct a bond feature from a RDKit bond object.

  The feature is the concatenation, in this order, of the bond type,
  same-ring, conjugation and stereo encodings returned by the helpers
  imported from `deepchem.utils.graph_conv_utils`.

  Parameters
  ----------
  bond: rdkit.Chem.rdchem.Bond
    RDKit bond object

  Returns
  -------
  List[float]
    A one-hot vector of the bond feature.
  """
  bond_type = get_bond_type_one_hot(bond)
  same_ring = get_bond_is_in_same_ring_one_hot(bond)
  conjugated = get_bond_is_conjugated_one_hot(bond)
  stereo = get_bond_stereo_one_hot(bond)
  # The concatenation order defines the feature layout; keep it stable.
  return bond_type + same_ring + conjugated + stereo


class MolGraphConvFeaturizer(MolecularFeaturizer):
  """This class is a featurizer of gerneral graph convolution networks for molecules.

  The default featurization is based on WeaveNet style edge and node annotation.
  The default node(atom) and edge(bond) representations are based on WeaveNet paper.
  If you want to use your own representations, you could use this class as a guide
  to define your original Featurizer. In many cases, it's enough to modify return values of
  `constrcut_atom_feature` or `construct_bond_feature`.

  The default node representation is constructed by concatenating the following values,
  and the feature length is 38.

  - Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Br", "I", "other atoms".
  - Chirality: A one-hot vector of the chirality, "R" or "S".
  - Formal charge: Integer electronic charge.
  - Partial charge: Calculated partial charge.
  - Ring sizes: A one-hot vector of the size (3-8) of the rings that include this atom.
  - Hybridization: A one-hot vector of "sp", "sp2", "sp3".
  - Hydrogen bonding: A one-hot vector of whether this atom is a hydrogen bond donor or acceptor.
  - Aromatic: A one-hot vector of whether the atom belongs to an aromatic ring.
  - Degree: A one-hot vector of the degree (0-5) of this atom.
  - Number of Hydrogens: A one-hot vector of the number of hydrogens (0-4) that this atom connected.

  TODO: add more docstrings.
  The default edge representation is constructed by concatenating the following values,
  and the feature length is 11.

  - Bond type: A one-hot vector of the bond type, "single", "double", "triple", or "aromatic".
  - Same ring: A one-hot vector of whether the atoms in the pair are in the same ring.
  - Conjugated: A one-hot vector of whether this bond is conjugated or not.
  - Stereo: A one-hot vector of the stereo configuration of a bond.

  If you want to know more details about features, please check the paper [1]_ and
  utilities in deepchem.utils.graph_conv_utils.py.

  Examples
  -------
  --------
  >>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
  >>> featurizer = MolGraphConvFeaturizer()
  >>> out = featurizer.featurize(smiles)
  >>> type(out[0])
  <class 'deepchem.feat.graph_data.GraphData'>

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints."
     Journal of computer-aided molecular design 30.8 (2016):595-608.

  Notes
  -----
  This class requires RDKit to be installed.
  """

  def __init__(self, add_self_loop: bool = False, use_mpnn_style: bool = False):
  def __init__(self, add_self_loop: bool = False):
    """
    Parameters
    ----------
    add_self_loop: bool, default False
      TODO: Docstring
    use_mpnn_style: bool, default False
      TODO: Docstring
      Whether to add self-connected edges or not. If you want to use DGL,
      you sometimes need to add explicit self-connected edges.
    """
    self.add_self_loop = add_self_loop
    self.use_mpnn_style = use_mpnn_style

  def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.
@@ -117,70 +145,24 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
    """
    try:
      from rdkit import Chem
      from rdkit.Chem import rdmolops, AllChem
      from rdkit.Chem import AllChem
    except ModuleNotFoundError:
      raise ValueError("This method requires RDKit to be installed.")

    # construct atom and bond features
    hydrogen_bonding = construct_hydrogen_bonding_info(mol)
    if self.use_mpnn_style:
      # MPNN style
      # compute 3D coordinate. Sometimes, this operation raise Error
      mol_for_coord = AllChem.AddHs(mol)
      conf_id = AllChem.EmbedMolecule(mol_for_coord)
      mol_for_coord = AllChem.RemoveHs(mol_for_coord)
      dist_matrix = rdmolops.Get3DDistanceMatrix(mol_for_coord, confId=conf_id)

      # construct atom (node) feature
      atom_features = np.array(
          [
              constrcut_atom_feature(atom, self.use_mpnn_style,
                                     hydrogen_bonding)
              for atom in mol.GetAtoms()
          ],
          dtype=np.float,
      )

      # construct edge (bond) information
      src, dist, bond_features = [], [], []
      for bond in mol.GetBonds():
        # add edge list considering a directed graph
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src += [start, end]
        dist += [end, start]
        bond_features += 2 * [
            construct_bond_feature(
                bond, self.use_mpnn_style, euclidean_dist_matrix=dist_matrix)
        ]

      if self.add_self_loop:
        src += [i for i in range(mol.GetNumAtoms())]
        dist += [i for i in range(mol.GetNumAtoms())]
        bond_fea_length = len(bond_features[0])
        bond_features += 2 * [[0 for _ in range(bond_fea_length)]]

      return GraphData(
          node_features=atom_features,
          edge_index=np.array([src, dist], dtype=np.int),
          edge_features=np.array(bond_features, dtype=np.float))

    # Weave style
    # compute partial charges
    try:
      mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
      pass
    except:
      # If partial charges were not computed
      AllChem.ComputeGasteigerCharges(mol)

    dist_matrix = Chem.GetDistanceMatrix(mol)
    chiral_center = Chem.FindMolChiralCenters(mol)
    h_bond_infos = construct_hydrogen_bonding_info(mol)
    sssr = Chem.GetSymmSSSR(mol)

    # construct atom (node) feature
    atom_features = np.array(
        [
            constrcut_atom_feature(atom, self.use_mpnn_style, hydrogen_bonding,
                                   chiral_center, sssr)
            constrcut_atom_feature(atom, h_bond_infos, sssr)
            for atom in mol.GetAtoms()
        ],
        dtype=np.float,
@@ -193,16 +175,15 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
      start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
      src += [start, end]
      dist += [end, start]
      bond_features += 2 * [
          construct_bond_feature(
              bond, self.use_mpnn_style, graph_dist_matrix=dist_matrix)
      ]
      bond_features += 2 * [construct_bond_feature(bond)]

    if self.add_self_loop:
      src += [i for i in range(mol.GetNumAtoms())]
      dist += [i for i in range(mol.GetNumAtoms())]
      num_atoms = mol.GetNumAtoms()
      src += [i for i in range(num_atoms)]
      dist += [i for i in range(num_atoms)]
      # add dummy edge features
      bond_fea_length = len(bond_features[0])
      bond_features += 2 * [[0 for _ in range(bond_fea_length)]]
      bond_features += num_atoms * [[0 for _ in range(bond_fea_length)]]

    return GraphData(
        node_features=atom_features,
+4 −5
Original line number Diff line number Diff line
import unittest
import pytest
import numpy as np
from deepchem.feat.graph_data import GraphData, BatchGraphData

@@ -38,7 +37,7 @@ class TestGraph(unittest.TestCase):
    assert isinstance(dgl_graph, DGLGraph)

  def test_invalid_graph_data(self):
    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      invalid_node_features_type = list(np.random.random_sample((5, 32)))
      edge_index = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -49,7 +48,7 @@ class TestGraph(unittest.TestCase):
          edge_index=edge_index,
      )

    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      node_features = np.random.random_sample((5, 32))
      invalid_edge_index_shape = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -60,7 +59,7 @@ class TestGraph(unittest.TestCase):
          edge_index=invalid_edge_index_shape,
      )

    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      node_features = np.random.random_sample((5, 5))
      invalid_edge_index_shape = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -72,7 +71,7 @@ class TestGraph(unittest.TestCase):
          edge_index=invalid_edge_index_shape,
      )

    with pytest.raises(TypeError):
    with self.assertRaises(TypeError):
      node_features = np.random.random_sample((5, 32))
      _ = GraphData(node_features=node_features)

+13 −13
Original line number Diff line number Diff line
@@ -3,8 +3,8 @@ import unittest
from deepchem.feat import MolGraphConvFeaturizer


# TODO: Add more test cases
class TestMolGraphConvFeaturizer(unittest.TestCase):

  def test_default_featurizer(self):
    smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    featurizer = MolGraphConvFeaturizer()
@@ -13,30 +13,30 @@ class TestMolGraphConvFeaturizer(unittest.TestCase):

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 25
    assert graph_feat[0].num_node_features == 38
    assert graph_feat[0].num_edges == 12
    assert graph_feat[0].num_edge_features == 13
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 25
    assert graph_feat[1].num_node_features == 38
    assert graph_feat[1].num_edges == 44
    assert graph_feat[1].num_edge_features == 13
    assert graph_feat[1].num_edge_features == 11

  def test_featurizer_with_self_loop(self):
    """Check graph sizes when explicit self-connected edges are requested.

    With `add_self_loop=True` one extra edge is expected per atom, on top
    of the two directed edges emitted for every bond.
    """
    smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    featurizer = MolGraphConvFeaturizer(add_self_loop=True)
    graph_feat = featurizer.featurize(smiles)
    assert len(graph_feat) == 2

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 38
    assert graph_feat[0].num_edges == 12 + 6
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 38
    assert graph_feat[1].num_edges == 44 + 22
    assert graph_feat[1].num_edge_features == 11
+165 −121

File changed.

Preview size limit exceeded, changes collapsed.

+174 −6
Original line number Diff line number Diff line
import unittest


from deepchem.utils.graph_conv_utils import one_hot_encode
from deepchem.utils.graph_conv_utils import one_hot_encode, \
  get_atom_type_one_hot, construct_hydrogen_bonding_info, \
  get_atom_hydrogen_bonding_one_hot, get_atom_is_in_aromatic_one_hot, \
  get_atom_hybridization_one_hot, get_atom_total_num_Hs_one_hot, get_atom_chirality_one_hot, \
  get_atom_formal_charge, get_atom_partial_charge, get_atom_ring_size_one_hot, \
  get_atom_total_degree_one_hot, get_bond_type_one_hot, get_bond_is_in_same_ring_one_hot, \
  get_bond_is_conjugated_one_hot, get_bond_stereo_one_hot, get_bond_graph_distance_one_hot


# TODO: add more test cases
class TestGraphConvUtils(unittest.TestCase):
  """Unit tests for the feature helpers of `deepchem.utils.graph_conv_utils`."""

  def setUp(self):
    """Build the RDKit molecules shared by the test cases."""
    from rdkit import Chem
    # methyl isocyanate
    self.mol = Chem.MolFromSmiles("CN=C=O")
    self.mol_copper_sulfate = Chem.MolFromSmiles("[Cu+2].[O-]S(=O)(=O)[O-]")
    self.mol_benzene = Chem.MolFromSmiles("c1ccccc1")
    self.mol_s_alanine = Chem.MolFromSmiles("N[C@@H](C)C(=O)O")

  def test_one_hot_encode(self):
    """Encode known values, unknown values, and the optional unknown slot."""
    # string set
    encoded = one_hot_encode("a", ["a", "b", "c"])
    assert encoded == [1, 0, 0]
    encoded = one_hot_encode("a", ["a", "b", "c"])
    assert encoded == [1.0, 0.0, 0.0]
    # integer set
    encoded = one_hot_encode(2, [0, 1, 2])
    assert encoded == [0, 0, 1]
    encoded = one_hot_encode(2, [0.0, 1, 2])
    assert encoded == [0.0, 0.0, 1.0]
    # include_unknown_set is False
    encoded = one_hot_encode(3, [0, 1, 2])
    assert encoded == [0, 0, 0]
    encoded = one_hot_encode(3, [0.0, 1, 2])
    assert encoded == [0.0, 0.0, 0.0]
    # include_unknown_set is True
    encoded = one_hot_encode(3, [0, 1, 2], True)
    assert encoded == [0, 0, 0, 1]
    encoded = one_hot_encode(3, [0.0, 1, 2], True)
    assert encoded == [0.0, 0.0, 0.0, 1.0]

  def test_get_atom_type_one_hot(self):
    """Encode atom symbols with the default set, an unknown atom and a custom set."""
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_type_one_hot(atom_seq[0])
    assert encoded == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    # check unknown atoms
    atom_seq = self.mol_copper_sulfate.GetAtoms()
    assert atom_seq[0].GetSymbol() == "Cu"
    encoded = get_atom_type_one_hot(atom_seq[0])
    assert encoded == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    encoded = get_atom_type_one_hot(atom_seq[0], include_unknown_set=False)
    assert encoded == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    # check original set
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[1].GetSymbol() == "N"
    custom_symbols = ["C", "O", "N"]
    encoded = get_atom_type_one_hot(atom_seq[1], allowable_set=custom_symbols)
    assert encoded == [0.0, 0.0, 1.0, 0.0]

  def test_construct_hydrogen_bonding_info(self):
    """The hydrogen bonding info is a list of `(atom_index, type)` tuples."""
    h_bond_infos = construct_hydrogen_bonding_info(self.mol)
    assert isinstance(h_bond_infos, list)
    assert isinstance(h_bond_infos[0], tuple)
    # Generally, =O behaves as an electron acceptor
    assert h_bond_infos[0] == (3, "Acceptor")

  def test_get_atom_hydrogen_bonding_one_hot(self):
    """Encode the hydrogen bonding role of a carbon and of the carbonyl oxygen."""
    h_bond_infos = construct_hydrogen_bonding_info(self.mol)
    atom_seq = self.mol.GetAtoms()

    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_hydrogen_bonding_one_hot(atom_seq[0], h_bond_infos)
    assert encoded == [0.0, 0.0]

    assert atom_seq[3].GetSymbol() == "O"
    encoded = get_atom_hydrogen_bonding_one_hot(atom_seq[3], h_bond_infos)
    assert encoded == [0.0, 1.0]

  def test_get_atom_is_in_aromatic_one_hot(self):
    """Encode aromaticity for a non-aromatic carbon and a benzene carbon."""
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_is_in_aromatic_one_hot(atom_seq[0])
    assert encoded == [0.0]

    atom_seq = self.mol_benzene.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_is_in_aromatic_one_hot(atom_seq[0])
    assert encoded == [1.0]

  def test_get_atom_hybridization_one_hot(self):
    """Encode the hybridization of the methyl carbon."""
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_hybridization_one_hot(atom_seq[0])
    assert encoded == [0.0, 0.0, 1.0]

  def test_get_atom_total_num_Hs_one_hot(self):
    """Encode the hydrogen count of the methyl carbon and the oxygen."""
    atom_seq = self.mol.GetAtoms()

    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_total_num_Hs_one_hot(atom_seq[0])
    assert encoded == [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]

    assert atom_seq[3].GetSymbol() == "O"
    encoded = get_atom_total_num_Hs_one_hot(atom_seq[3])
    assert encoded == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

  def test_get_atom_chirality_one_hot(self):
    """Encode chirality for the nitrogen and the stereo carbon of the alanine molecule."""
    atom_seq = self.mol_s_alanine.GetAtoms()

    assert atom_seq[0].GetSymbol() == "N"
    encoded = get_atom_chirality_one_hot(atom_seq[0])
    assert encoded == [0.0, 0.0]

    assert atom_seq[1].GetSymbol() == "C"
    encoded = get_atom_chirality_one_hot(atom_seq[1])
    assert encoded == [0.0, 1.0]

  def test_get_atom_formal_charge(self):
    """The formal charge of the methyl carbon is a single zero value."""
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    charge = get_atom_formal_charge(atom_seq[0])
    assert charge == [0.0]

  def test_get_atom_partial_charge(self):
    """The partial charge is only readable after it has been computed."""
    from rdkit.Chem import AllChem
    atom_seq = self.mol.GetAtoms()
    assert atom_seq[0].GetSymbol() == "C"
    # no charge has been stored on the atom yet
    with self.assertRaises(KeyError):
      get_atom_partial_charge(atom_seq[0])

    # we must compute partial charges before using `get_atom_partial_charge`
    AllChem.ComputeGasteigerCharges(self.mol)
    charge = get_atom_partial_charge(atom_seq[0])
    assert len(charge) == 1.0
    assert isinstance(charge[0], float)

  def test_get_atom_ring_size_one_hot(self):
    """Encode ring sizes for an acyclic carbon and a benzene carbon."""
    from rdkit import Chem
    atom_seq = self.mol.GetAtoms()
    rings = Chem.GetSymmSSSR(self.mol)
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_ring_size_one_hot(atom_seq[0], rings)
    assert encoded == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    atom_seq = self.mol_benzene.GetAtoms()
    rings = Chem.GetSymmSSSR(self.mol_benzene)
    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_ring_size_one_hot(atom_seq[0], rings)
    assert encoded == [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]

  def test_get_atom_total_degree_one_hot(self):
    """Encode the total degree of the methyl carbon and the oxygen."""
    atom_seq = self.mol.GetAtoms()

    assert atom_seq[0].GetSymbol() == "C"
    encoded = get_atom_total_degree_one_hot(atom_seq[0])
    assert encoded == [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]

    assert atom_seq[3].GetSymbol() == "O"
    encoded = get_atom_total_degree_one_hot(atom_seq[3])
    assert encoded == [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

  def test_get_bond_type_one_hot(self):
    """Encode the bond type of the first bond of methyl isocyanate."""
    bond_seq = self.mol.GetBonds()
    encoded = get_bond_type_one_hot(bond_seq[0])
    # The C-N bond is a single bond
    assert bond_seq[0].GetBeginAtomIdx() == 0.0
    assert bond_seq[0].GetEndAtomIdx() == 1.0
    assert encoded == [1.0, 0.0, 0.0, 0.0]

  def test_get_bond_is_in_same_ring_one_hot(self):
    """Encode ring membership for an acyclic bond and a benzene bond."""
    bond_seq = self.mol.GetBonds()
    encoded = get_bond_is_in_same_ring_one_hot(bond_seq[0])
    assert encoded == [0.0]

    bond_seq = self.mol_benzene.GetBonds()
    encoded = get_bond_is_in_same_ring_one_hot(bond_seq[0])
    assert encoded == [1.0]

  def test_get_bond_is_conjugated_one_hot(self):
    """Encode conjugation for the C-N bond and a benzene bond."""
    bond_seq = self.mol.GetBonds()
    encoded = get_bond_is_conjugated_one_hot(bond_seq[0])
    assert encoded == [0.0]

    bond_seq = self.mol_benzene.GetBonds()
    encoded = get_bond_is_conjugated_one_hot(bond_seq[0])
    assert encoded == [1.0]

  def test_get_bond_stereo_one_hot(self):
    """Encode the stereo configuration of the first bond of methyl isocyanate."""
    bond_seq = self.mol.GetBonds()
    encoded = get_bond_stereo_one_hot(bond_seq[0])
    assert encoded == [1.0, 0.0, 0.0, 0.0, 0.0]

  def test_get_bond_graph_distance_one_hot(self):
    """Encode the graph distance between the two atoms of the first bond."""
    from rdkit import Chem
    bond_seq = self.mol.GetBonds()
    topological_dist = Chem.GetDistanceMatrix(self.mol)
    encoded = get_bond_graph_distance_one_hot(bond_seq[0], topological_dist)
    assert encoded == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Loading