Commit ad60fae5 authored by miaecle's avatar miaecle
Browse files

temporary save

parent 5f1e81f6
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
@@ -132,6 +132,21 @@ def bond_features(bond):
                   bond.GetIsConjugated(),
                   bond.IsInRing()])
  
def pair_features(mol, canon_adj_list):
  features = np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms(), 12))
  for a1 in mol.GetAtoms():
    a1_id = a1.GetIdx()
    for a2 in mol.GetAtoms():
      a2_id = a2.GetIdx()
      if a2_id in canon_adj_list[a1_id]:
        bt = bond_features(mol.GetBondBetweenAtoms(a1_id, a2_id))
  return np.array([bt == Chem.rdchem.BondType.SINGLE,
                   bt == Chem.rdchem.BondType.DOUBLE,
                   bt == Chem.rdchem.BondType.TRIPLE,
                   bt == Chem.rdchem.BondType.AROMATIC,
                   bond.GetIsConjugated(),
                   bond.IsInRing()])

class ConvMolFeaturizer(Featurizer):

  name = ['conv_mol']
@@ -160,3 +175,35 @@ class ConvMolFeaturizer(Featurizer):
      canon_adj_list[edge[1]].append(edge[0])

    return ConvMol(nodes, canon_adj_list)
    

class WeaveFeaturizer(Featurizer):

  name = ['weave_mol']
  def __init__(self):
    # Since ConvMol is an object and not a numpy array, need to set dtype to
    # object.
    self.dtype = object

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    # Atom features
    idx_nodes = [(a.GetIdx(), atom_features(a)) for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))

    # Stack nodes into an array
    nodes = np.vstack(nodes)

    # Get bond lists with reverse edges included
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
    for edge in edge_list:
      canon_adj_list[edge[0]].append(edge[1])
      canon_adj_list[edge[1]].append(edge[0])
    
    pairs = pair_features(mol, canon_adj_list)
      
    return ConvMol(nodes, canon_adj_list)
 No newline at end of file