Commit 7af7edae authored by Vignesh's avatar Vignesh
Browse files

Fixes to chirality for Weave

parent 26559014
Loading
Loading
Loading
Loading
+14 −11
Original line number Diff line number Diff line
@@ -73,6 +73,8 @@ reference_lists = [
]

intervals = get_intervals(reference_lists)
possible_bond_stereo = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
bond_fdim_base = 6


def get_feature_list(atom):
@@ -210,8 +212,7 @@ def bond_features(bond, use_chirality=False):
  ]
  if use_chirality:
    bond_feats = bond_feats + one_of_k_encoding_unk(
        str(bond.GetStereo()),
        ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])
        str(bond.GetStereo()), possible_bond_stereo)
  return np.array(bond_feats)


@@ -396,6 +397,10 @@ class WeaveFeaturizer(Featurizer):
    self.explicit_H = explicit_H
    # If uses use_chirality
    self.use_chirality = use_chirality
    if self.use_chirality:
      self.bt_len = bond_fdim_base + len(possible_bond_stereo)
    else:
      self.bt_len = bond_fdim_base

  def _featurize(self, mol):
    """Encodes mol as a WeaveMol object."""
@@ -425,14 +430,12 @@ class WeaveFeaturizer(Featurizer):
      canon_adj_list[edge[0]].append(edge[1])
      canon_adj_list[edge[1]].append(edge[0])

    bt_len = len(list(edge_list.values())[0])

    # Calculate pair features
    pairs = pair_features(
        mol,
        edge_list,
        canon_adj_list,
        bt_len=bt_len,
        bt_len=self.bt_len,
        graph_distance=self.graph_distance)

    return WeaveMol(nodes, pairs)