Unverified Commit dcceba58 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1080 from xiongzhp/PRs

add chirality to atom and bond features
parents 415edc77 856ac0d9
Loading
Loading
Loading
Loading
+37 −11
Original line number Diff line number Diff line
@@ -56,11 +56,12 @@ possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]
possible_chirality_list = ['R', 'S']

reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list
    possible_hybridization_list, possible_chirality_list
]

intervals = get_intervals(reference_lists)
@@ -75,7 +76,6 @@ def get_feature_list(atom):
  features[4] = safe_index(possible_number_radical_e_list,
                           atom.GetNumRadicalElectrons())
  features[5] = safe_index(possible_hybridization_list, atom.GetHybridization())

  return features


@@ -111,7 +111,10 @@ def atom_to_id(atom):
  return features_to_id(features, intervals)


def atom_features(atom, bool_id_feat=False, explicit_H=False):
def atom_features(atom,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=False):
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
  else:
@@ -175,18 +178,31 @@ def atom_features(atom, bool_id_feat=False, explicit_H=False):
    if not explicit_H:
      results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                [0, 1, 2, 3, 4])
    if use_chirality:
      try:
        results = results + one_of_k_encoding_unk(
            atom.GetProp('_CIPCode'),
            ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
      except:
        results = results + [False, False
                            ] + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)


def bond_features(bond):
def bond_features(bond, use_chirality=False):
  bt = bond.GetBondType()
  return np.array([
  bond_feats = [
      bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
      bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(),
      bond.IsInRing()
  ])
  ]
  if use_chirality:
    bond_feats = bond_feats + one_of_k_encoding_unk(
        str(bond.GetStereo(),
            ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]))
  return np.array(bond_feats)


def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
@@ -249,7 +265,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
class ConvMolFeaturizer(Featurizer):
  name = ['conv_mol']

  def __init__(self, master_atom=False):
  def __init__(self, master_atom=False, use_chirality=False):
    """
    Parameters
    ----------
@@ -266,11 +282,13 @@ class ConvMolFeaturizer(Featurizer):
    """
    self.dtype = object
    self.master_atom = master_atom
    self.use_chirality = use_chirality

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    # Get the node features
    idx_nodes = [(a.GetIdx(), atom_features(a)) for a in mol.GetAtoms()]
    idx_nodes = [(a.GetIdx(), atom_features(
        a, use_chirality=self.use_chirality)) for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))

@@ -302,7 +320,8 @@ class ConvMolFeaturizer(Featurizer):
class WeaveFeaturizer(Featurizer):
  name = ['weave_mol']

  def __init__(self, graph_distance=True, explicit_H=False):
  def __init__(self, graph_distance=True, explicit_H=False,
               use_chirality=False):
    # Distance is either graph distance(True) or Euclidean distance(False,
    # only support datasets providing Cartesian coordinates)
    self.graph_distance = graph_distance
@@ -310,11 +329,17 @@ class WeaveFeaturizer(Featurizer):
    self.dtype = object
    # If includes explicit hydrogens
    self.explicit_H = explicit_H
    # If uses use_chirality
    self.use_chirality = use_chirality

  def _featurize(self, mol):
    """Encodes mol as a WeaveMol object."""
    # Atom features
    idx_nodes = [(a.GetIdx(), atom_features(a, explicit_H=self.explicit_H))
    idx_nodes = [(a.GetIdx(),
                  atom_features(
                      a,
                      explicit_H=self.explicit_H,
                      use_chirality=self.use_chirality))
                 for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))
@@ -326,7 +351,8 @@ class WeaveFeaturizer(Featurizer):
    edge_list = {}
    for b in mol.GetBonds():
      edge_list[tuple(sorted([b.GetBeginAtomIdx(),
                              b.GetEndAtomIdx()]))] = bond_features(b)
                              b.GetEndAtomIdx()]))] = bond_features(
                                  b, use_chirality=self.use_chirality)

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]