Commit 7b584390 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

yapf rerun

parent 3c81a59c
Loading
Loading
Loading
Loading
+9 −8
Original line number Diff line number Diff line
@@ -123,12 +123,13 @@ class MolGanFeaturizer(MolecularFeaturizer):
    if self.kekulize:
      Chem.Kekulize(mol)

    A = np.zeros(shape=(self.max_atom_count, self.max_atom_count),
                 dtype=np.float32)
    A = np.zeros(
        shape=(self.max_atom_count, self.max_atom_count), dtype=np.float32)
    bonds = mol.GetBonds()

    begin, end = [b.GetBeginAtomIdx() for b in bonds
                 ], [b.GetEndAtomIdx() for b in bonds]
    begin, end = [b.GetBeginAtomIdx() for b in bonds], [
        b.GetEndAtomIdx() for b in bonds
    ]
    bond_type = [self.bond_encoder[b.GetBondType()] for b in bonds]

    A[begin, end] = bond_type
@@ -189,8 +190,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

    for start, end in zip(*np.nonzero(edge_labels)):
      if start > end:
        mol.AddBond(int(start), int(end), self.bond_decoder[edge_labels[start,
                                                                        end]])
        mol.AddBond(
            int(start), int(end), self.bond_decoder[edge_labels[start, end]])

    if sanitize:
      try:
@@ -211,8 +212,7 @@ class MolGanFeaturizer(MolecularFeaturizer):

    return mol

  def defeaturize(self,
                  graphs: OneOrMany[GraphMatrix],
  def defeaturize(self, graphs: OneOrMany[GraphMatrix],
                  log_every_n: int = 1000) -> np.ndarray:
    """
    Calculates molecules from corresponding GraphMatrix objects.
@@ -223,6 +223,7 @@ class MolGanFeaturizer(MolecularFeaturizer):
      GraphMatrix object or corresponding iterable
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.
      
    Returns
    -------
    features: np.ndarray