Unverified Commit d6368781 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2597 from Suzukazole/rxntransformer

 Reaction Split Transformer
parents bd910d32 54ed4814
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -21,4 +21,5 @@ from deepchem.trans.transformers import ImageTransformer
from deepchem.trans.transformers import DataTransforms
from deepchem.trans.transformers import Transformer
from deepchem.trans.transformers import FlatteningTransformer
from deepchem.trans.transformers import RxnSplitTransformer
from deepchem.trans.duplicate import DuplicateBalancingTransformer
+55 −0
Original line number Diff line number Diff line
import unittest
import numpy as np

from deepchem.trans.transformers import RxnSplitTransformer

# Input fixture: reaction SMILES of the form reactants>reagents>products.
# The second entry has an empty reagent field (">>").
_reaction_smiles = [
    "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]>CC(C)CC(=O)c1ccc(O)nc1",
    "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>>CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21",
]
reactions: np.ndarray = np.array(_reaction_smiles, dtype=object)

# Expected (source, target) rows when the reagents stay behind their own '>'
# separator, i.e. the source reads reactants>reagents.
_split_rows = [
    ("CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]",
     "CC(C)CC(=O)c1ccc(O)nc1"),
    ("CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
     "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"),
]
split: np.ndarray = np.array(_split_rows, dtype=object)

# Expected (source, target) rows when the reagents are merged into the
# reactants, i.e. the source reads reactants.reagents> .
_sep_rows = [
    ("CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1.C1CCOC1.[Cl-]>",
     "CC(C)CC(=O)c1ccc(O)nc1"),
    ("CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
     "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"),
]
sep: np.ndarray = np.array(_sep_rows, dtype=object)


class TestRxnSplitTransformer(unittest.TestCase):
  """
  Checks RxnSplitTransformer in both modes: with the reagents kept in their
  own field of the source string, and with the reagents merged into the
  reactants.
  """

  def _transformed_features(self, sep_reagent):
    """Runs the transformer on the module-level reactions and returns X.

    The label, weight and id arrays are passed in empty; only the first
    element of the returned tuple (the transformed features) is inspected.
    """
    transformer = RxnSplitTransformer(sep_reagent=sep_reagent)
    outputs = transformer.transform_array(
        X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
    features = outputs[0]
    # One (source, target) row is expected per input reaction.
    assert features.shape == (2, 2)
    return features

  def test_split(self):
    """With sep_reagent=True the source must read reactants>reagents."""
    features = self._transformed_features(sep_reagent=True)
    assert np.all(features == split)

  def test_mixing(self):
    """With sep_reagent=False the source must read reactants.reagents> ."""
    features = self._transformed_features(sep_reagent=False)
    assert np.all(features == sep)
+113 −0
Original line number Diff line number Diff line
@@ -2469,3 +2469,116 @@ class DataTransforms(object):
    image = Image.fromarray(self.Image)
    image = image.filter(ImageFilter.MedianFilter(size=size))
    return np.array(image)


class RxnSplitTransformer(Transformer):
  """Splits the reaction SMILES input into the source and target strings
  required for machine translation tasks.

  Each input is expected to be in the form reactants>reagents>products. The
  target string is always the products. The source string depends on
  ``sep_reagent``:

  - ``sep_reagent=True`` (default): the reagents stay separated from the
    reactants and the source is ``reactants>reagents``.
  - ``sep_reagent=False`` (mixed training): the reagents are merged into the
    reactants and the source is ``reactants.reagents>``. An empty reactant or
    reagent field is dropped from the merge, so no dangling ``.`` is emitted.

  Examples
  --------
  Separated mode keeps the reagents in their own field of the source string.

  >>> import numpy as np
  >>> from deepchem.trans.transformers import RxnSplitTransformer
  >>> reactions = np.array(
  ...     [
  ...         "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]>CC(C)CC(=O)c1ccc(O)nc1",
  ...         "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>>CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
  ...     ],
  ...     dtype=object)
  >>> empty = np.array([])
  >>> trans = RxnSplitTransformer(sep_reagent=True)
  >>> X_sep, _, _, _ = trans.transform_array(X=reactions, y=empty, w=empty, ids=empty)
  >>> X_sep.shape
  (2, 2)
  >>> print(X_sep[0, 0])
  CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]
  >>> print(X_sep[0, 1])
  CC(C)CC(=O)c1ccc(O)nc1

  Mixed mode merges the reagents into the reactants:

  >>> trans_mixed = RxnSplitTransformer(sep_reagent=False)
  >>> X_mixed, _, _, _ = trans_mixed.transform_array(X=reactions, y=empty, w=empty, ids=empty)
  >>> print(X_mixed[0, 0])
  CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1.C1CCOC1.[Cl-]>
  >>> print(X_mixed[1, 0])
  CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>

  Note
  ----
  This class only transforms the feature field of a reaction dataset like USPTO.
  """

  def __init__(self,
               sep_reagent: bool = True,
               dataset: Optional[Dataset] = None):
    """Initializes the Reaction split Transformer.

    Parameters
    ----------
    sep_reagent: bool, optional (default True)
      If True, keep the reagents separated from the reactants in the source
      string (``reactants>reagents``). If False, merge them
      (``reactants.reagents>``).
    dataset: dc.data.Dataset object, optional (default None)
      Dataset to be transformed.
    """

    self.sep_reagent = sep_reagent
    super(RxnSplitTransformer, self).__init__(transform_X=True, dataset=dataset)

  def transform_array(
      self, X: np.ndarray, y: np.ndarray, w: np.ndarray,
      ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Transform the data in a set of (X, y, w, ids) arrays.

    Only X is changed; y, w and ids are returned as they were passed in.

    Parameters
    ----------
    X: np.ndarray
      Array of features (the reaction SMILES strings)
    y: np.ndarray
      Array of labels
    w: np.ndarray
      Array of weights.
    ids: np.ndarray
      Array of identifiers.

    Returns
    -------
    Xtrans: np.ndarray
      Two-column array; column 0 holds the source strings and column 1 the
      target (product) strings, one row per input reaction.
    ytrans: np.ndarray
      The labels, unchanged
    wtrans: np.ndarray
      The weights, unchanged
    idstrans: np.ndarray
      The ids, unchanged

    Raises
    ------
    IndexError
      If a reaction string contains fewer than two '>' separators.
    """

    source = []
    target = []
    for reaction in X:
      # Split once per reaction; only the first three fields are used.
      fields = reaction.split('>')
      reactant, reagent, product = fields[0], fields[1], fields[2]

      if self.sep_reagent:
        source.append(reactant + '>' + reagent)
      else:
        # Join only the non-empty fields so that an empty reactant or reagent
        # does not leave a dangling '.' in the mixed source string.
        mixed = '.'.join(part for part in (reactant, reagent) if part)
        source.append(mixed + '>')

      target.append(product)

    X = np.column_stack((source, target))

    return (X, y, w, ids)

  def untransform(self, z):
    """Not Implemented."""
    raise NotImplementedError("Cannot untransform the source/target split.")
+7 −0
Original line number Diff line number Diff line
@@ -108,6 +108,13 @@ DAGTransformer
  :members:
  :inherited-members:

RxnSplitTransformer
^^^^^^^^^^^^^^^^^^^

.. autoclass:: deepchem.trans.RxnSplitTransformer
  :members:
  :inherited-members:

Base Transformer (for developers)
---------------------------------