Commit 856ac0d9 authored by ErikXiong's avatar ErikXiong
Browse files

Adopted Yutong's suggestions using R/S and E/Z labeling for atoms and bonds,...

Adopted Yutong's suggestions using R/S and E/Z labeling for atoms and bonds, and formatted the changed file with yapf.
I am surprised to find ConvMolFeaturizer doesn't even use bond_features.
parent 50f076da
Loading
Loading
Loading
Loading
+37 −21
Original line number Diff line number Diff line
@@ -50,15 +50,13 @@ possible_atom_list = [
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_chirality_list = [
    Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.ChiralType.CHI_UNSPECIFIED]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]
possible_chirality_list = ['R', 'S']

reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
@@ -70,7 +68,7 @@ intervals = get_intervals(reference_lists)


def get_feature_list(atom):
  features = 7 * [0]
  features = 6 * [0]
  features[0] = safe_index(possible_atom_list, atom.GetSymbol())
  features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs())
  features[2] = safe_index(possible_valence_list, atom.GetImplicitValence())
@@ -78,7 +76,6 @@ def get_feature_list(atom):
  features[4] = safe_index(possible_number_radical_e_list,
                           atom.GetNumRadicalElectrons())
  features[5] = safe_index(possible_hybridization_list, atom.GetHybridization())
  features[6] = safe_index(possible_chirality_list, atom.GetChiralTag())
  return features


@@ -94,15 +91,15 @@ def features_to_id(features, intervals):


def id_to_features(id, intervals):
  features = 7 * [0]
  features = 6 * [0]

  # Correct for null
  id -= 1

  for k in range(0, 7 - 1):
    # print(7-k-1, id)
    features[7 - k - 1] = id // intervals[7 - k - 1]
    id -= features[7 - k - 1] * intervals[7 - k - 1]
  for k in range(0, 6 - 1):
    # print(6-k-1, id)
    features[6 - k - 1] = id // intervals[6 - k - 1]
    id -= features[6 - k - 1] * intervals[6 - k - 1]
  # Correct for last one
  features[0] = id
  return features
@@ -114,7 +111,10 @@ def atom_to_id(atom):
  return features_to_id(features, intervals)


def atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=False):
def atom_features(atom,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=False):
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
  else:
@@ -179,21 +179,30 @@ def atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=Fals
      results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                [0, 1, 2, 3, 4])
    if use_chirality:
      results = results + one_of_k_encoding_unk(atom.GetChiralTag(),
                [Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
                 Chem.ChiralType.CHI_UNSPECIFIED]) 
      try:
        results = results + one_of_k_encoding_unk(
            atom.GetProp('_CIPCode'),
            ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
      except:
        results = results + [False, False
                            ] + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)


def bond_features(bond):
def bond_features(bond, use_chirality=False):
  bt = bond.GetBondType()
  return np.array([
  bond_feats = [
      bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
      bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(),
      bond.IsInRing()
  ])
  ]
  if use_chirality:
    bond_feats = bond_feats + one_of_k_encoding_unk(
        str(bond.GetStereo(),
            ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]))
  return np.array(bond_feats)


def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
@@ -278,7 +287,8 @@ class ConvMolFeaturizer(Featurizer):
  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    # Get the node features
    idx_nodes = [(a.GetIdx(), atom_features(a, use_chirality=self.use_chirality)) for a in mol.GetAtoms()]
    idx_nodes = [(a.GetIdx(), atom_features(
        a, use_chirality=self.use_chirality)) for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))

@@ -310,7 +320,8 @@ class ConvMolFeaturizer(Featurizer):
class WeaveFeaturizer(Featurizer):
  name = ['weave_mol']

  def __init__(self, graph_distance=True, explicit_H=False, use_chirality=False):
  def __init__(self, graph_distance=True, explicit_H=False,
               use_chirality=False):
    # Distance is either graph distance(True) or Euclidean distance(False,
    # only support datasets providing Cartesian coordinates)
    self.graph_distance = graph_distance
@@ -324,7 +335,11 @@ class WeaveFeaturizer(Featurizer):
  def _featurize(self, mol):
    """Encodes mol as a WeaveMol object."""
    # Atom features
    idx_nodes = [(a.GetIdx(), atom_features(a, explicit_H=self.explicit_H, use_chirality=self.use_chirality))
    idx_nodes = [(a.GetIdx(),
                  atom_features(
                      a,
                      explicit_H=self.explicit_H,
                      use_chirality=self.use_chirality))
                 for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))
@@ -336,7 +351,8 @@ class WeaveFeaturizer(Featurizer):
    edge_list = {}
    for b in mol.GetBonds():
      edge_list[tuple(sorted([b.GetBeginAtomIdx(),
                              b.GetEndAtomIdx()]))] = bond_features(b)
                              b.GetEndAtomIdx()]))] = bond_features(
                                  b, use_chirality=self.use_chirality)

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]