Unverified Commit 70af35b8 authored by Suzukazole's avatar Suzukazole
Browse files

fix tests/docs

parent aa2d5237
Loading
Loading
Loading
Loading
+38 −16
Original line number Diff line number Diff line
@@ -10,24 +10,46 @@ reactions: np.ndarray = np.array(
    ],
    dtype=object)

split: np.ndarray = np.array(
    [[
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]",
        "CC(C)CC(=O)c1ccc(O)nc1"
    ], [
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ]],
    dtype=object)

sep: np.ndarray = np.array(
    [[
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1.C1CCOC1.[Cl-]>",
        "CC(C)CC(=O)c1ccc(O)nc1"
    ], [
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ]],
    dtype=object)


class TestRxnSplitTransformer(unittest.TestCase):
  """Tests the Reaction split transformer for the source/target splitting and
  """
  Tests the Reaction split transformer for the source/target splitting and
  for the reagent mixing operation.
  """

  def test_split(self):
    """
        Tests the source/target split from an input reaction SMILES.
        """
    """Tests the source/target split from an input reaction SMILES."""
    trans = RxnSplitTransformer(sep_reagent=True)
    split_reactions = trans.transform_array(X=reactions, y=[], w=[], ids=[])
    assert split_reactions[0].shape == (3, 2)
    # Should check for equality of split_reactions and sep_split!

  # def test_mixing(self):
  #    """
  #    Tests the reagent - reactant mixing option.
  #    """
  # WIP
  #    pass
    split_reactions = trans.transform_array(
        X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
    assert split_reactions[0].shape == (2, 2)
    assert (split_reactions[0] == split).all()

  def test_mixing(self):
    """Tests the reagent - reactant mixing toggle."""

    trans = RxnSplitTransformer(sep_reagent=False)
    split_reactions = trans.transform_array(
        X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
    assert split_reactions[0].shape == (2, 2)
    assert (split_reactions[0] == sep).all()
+2 −4
Original line number Diff line number Diff line
@@ -2502,7 +2502,7 @@ class RxnSplitTransformer(Transformer):
          ['CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>',
           'CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21']], dtype='<U51'), array([], dtype=float64), array([], dtype=float64), array([], dtype=float64))

  When mixed training is diabled, you get the following outputs
  When mixed training is disabled, you get the following outputs:

  >>> trans_disable = RxnSplitTransformer(sep_reagent=False)
  >>> split_reactions = trans_disable.transform_array(X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
@@ -2529,8 +2529,6 @@ class RxnSplitTransformer(Transformer):
    dataset: dc.data.Dataset object, optional (default None)
      Dataset to be transformed.
    """
    # the transformer would have to split the source and target sequences
    # would also consider adding the option of separating the reagent here.

    self.sep_reagent = sep_reagent
    super(RxnSplitTransformer, self).__init__(transform_X=True, dataset=dataset)