Commit 86eb315d authored by Milosz Grabski's avatar Milosz Grabski
Browse files

Update molgan_featurizer.py

parent 62db1978
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -124,13 +124,12 @@ class MolGanFeaturizer(MolecularFeaturizer):
    if self.kekulize:
      Chem.Kekulize(mol)

    A = np.zeros(
        shape=(self.max_atom_count, self.max_atom_count), dtype=np.float32)
    A = np.zeros(shape=(self.max_atom_count, self.max_atom_count),
                 dtype=np.float32)
    bonds = mol.GetBonds()

    begin, end = [b.GetBeginAtomIdx() for b in bonds], [
        b.GetEndAtomIdx() for b in bonds
    ]
    begin, end = [b.GetBeginAtomIdx() for b in bonds
                 ], [b.GetEndAtomIdx() for b in bonds]
    bond_type = [self.bond_encoder[b.GetBondType()] for b in bonds]

    A[begin, end] = bond_type
@@ -191,8 +190,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

    for start, end in zip(*np.nonzero(edge_labels)):
      if start > end:
        mol.AddBond(
            int(start), int(end), self.bond_decoder[edge_labels[start, end]])
        mol.AddBond(int(start), int(end), self.bond_decoder[edge_labels[start,
                                                                        end]])

    if sanitize:
      try:
@@ -213,7 +212,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

    return mol

  def defeaturize(self, graphs: OneOrMany[GraphMatrix],
  def defeaturize(self,
                  graphs: OneOrMany[GraphMatrix],
                  log_every_n: int = 1000) -> np.ndarray:
    """
    Calculates molecules from corresponding GraphMatrix objects.