Unverified Commit d6368781 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2597 from Suzukazole/rxntransformer

 Reaction Split Transformer
parents bd910d32 54ed4814
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -21,4 +21,5 @@ from deepchem.trans.transformers import ImageTransformer
from deepchem.trans.transformers import DataTransforms
from deepchem.trans.transformers import DataTransforms
from deepchem.trans.transformers import Transformer
from deepchem.trans.transformers import Transformer
from deepchem.trans.transformers import FlatteningTransformer
from deepchem.trans.transformers import FlatteningTransformer
from deepchem.trans.transformers import RxnSplitTransformer
from deepchem.trans.duplicate import DuplicateBalancingTransformer
from deepchem.trans.duplicate import DuplicateBalancingTransformer
+55 −0
Original line number Original line Diff line number Diff line
import unittest
import numpy as np

from deepchem.trans.transformers import RxnSplitTransformer

reactions: np.ndarray = np.array(
    [
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]>CC(C)CC(=O)c1ccc(O)nc1",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>>CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ],
    dtype=object)

split: np.ndarray = np.array(
    [[
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]",
        "CC(C)CC(=O)c1ccc(O)nc1"
    ], [
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ]],
    dtype=object)

sep: np.ndarray = np.array(
    [[
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1.C1CCOC1.[Cl-]>",
        "CC(C)CC(=O)c1ccc(O)nc1"
    ], [
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ]],
    dtype=object)


class TestRxnSplitTransformer(unittest.TestCase):
  """
  Tests the Reaction split transformer for the source/target splitting and
  for the reagent mixing operation.
  """

  def test_split(self):
    """Tests the source/target split from an input reaction SMILES."""
    trans = RxnSplitTransformer(sep_reagent=True)
    split_reactions = trans.transform_array(
        X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
    assert split_reactions[0].shape == (2, 2)
    assert (split_reactions[0] == split).all()

  def test_mixing(self):
    """Tests the reagent - reactant mixing toggle."""

    trans = RxnSplitTransformer(sep_reagent=False)
    split_reactions = trans.transform_array(
        X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
    assert split_reactions[0].shape == (2, 2)
    assert (split_reactions[0] == sep).all()
+113 −0
Original line number Original line Diff line number Diff line
@@ -2469,3 +2469,116 @@ class DataTransforms(object):
    image = Image.fromarray(self.Image)
    image = Image.fromarray(self.Image)
    image = image.filter(ImageFilter.MedianFilter(size=size))
    image = image.filter(ImageFilter.MedianFilter(size=size))
    return np.array(image)
    return np.array(image)


class RxnSplitTransformer(Transformer):
  """Splits the reaction SMILES input into the source and target strings
  required for machine translation tasks.

  The input is expected to be in the form reactant>reagent>product. The source
  string would be reactants>reagents and the target string would be the products.

  The transformer can also separate the reagents from the reactants for a mixed
  training mode. During mixed training, the source string is transformed from
  reactants>reagent to reactants.reagent> . This can be toggled (default True)
  by setting the value of sep_reagent while calling the transformer.

  Examples
  --------
  >>> # When mixed training is toggled.
  >>> import numpy as np
  >>> from deepchem.trans.transformers import RxnSplitTransformer
  >>> reactions = np.array(
    [
        "CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]>CC(C)CC(=O)c1ccc(O)nc1",
        "CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>>CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21"
    ],
    dtype=object)
  >>> trans = RxnSplitTransformer(sep_reagent=True)
  >>> split_reactions = trans.transform_array(X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
  >>> split_reactions
  (array([['CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1>C1CCOC1.[Cl-]',
           'CC(C)CC(=O)c1ccc(O)nc1'],
          ['CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>',
           'CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21']], dtype='<U51'), array([], dtype=float64), array([], dtype=float64), array([], dtype=float64))

  When mixed training is disabled, you get the following outputs:

  >>> trans_disable = RxnSplitTransformer(sep_reagent=False)
  >>> split_reactions = trans_disable.transform_array(X=reactions, y=np.array([]), w=np.array([]), ids=np.array([]))
  >>> split_reactions
  (array([['CC(C)C[Mg+].CON(C)C(=O)c1ccc(O)nc1.C1CCOC1.[Cl-]>',
           'CC(C)CC(=O)c1ccc(O)nc1'],
          ['CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(N)cc3)cc21.O=CO>',
           'CCn1cc(C(=O)O)c(=O)c2cc(F)c(-c3ccc(NC=O)cc3)cc21']], dtype='<U51'), array([], dtype=float64), array([], dtype=float64), array([], dtype=float64))

  Note
  ----
  This class only transforms the feature field of a reaction dataset like USPTO.
  """

  def __init__(self,
               sep_reagent: bool = True,
               dataset: Optional[Dataset] = None):
    """Initializes the Reaction split Transformer.

    Parameters
    ----------
    sep_reagent: bool, optional (default True)
      To separate the reagent and reactants for training.
    dataset: dc.data.Dataset object, optional (default None)
      Dataset to be transformed.
    """

    self.sep_reagent = sep_reagent
    super(RxnSplitTransformer, self).__init__(transform_X=True, dataset=dataset)

  def transform_array(
      self, X: np.ndarray, y: np.ndarray, w: np.ndarray,
      ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Transform the data in a set of (X, y, w, ids) arrays.

    Parameters
    ----------
    X: np.ndarray
      Array of features(the reactions)
    y: np.ndarray
      Array of labels
    w: np.ndarray
      Array of weights.
    ids: np.ndarray
      Array of weights.

    Returns
    -------
    Xtrans: np.ndarray
      Transformed array of features
    ytrans: np.ndarray
      Transformed array of labels
    wtrans: np.ndarray
      Transformed array of weights
    idstrans: np.ndarray
      Transformed array of ids
    """

    reactant = list(map(lambda x: x.split('>')[0], X))
    reagent = list(map(lambda x: x.split('>')[1], X))
    product = list(map(lambda x: x.split('>')[2], X))

    if self.sep_reagent:
      source = [x + '>' + y for x, y in zip(reactant, reagent)]
    else:
      source = [
          x + '.' + y + '>' if y else x + '>' + y
          for x, y in zip(reactant, reagent)
      ]

    target = product

    X = np.column_stack((source, target))

    return (X, y, w, ids)

  def untransform(self, z):
    """Not Implemented."""
    raise NotImplementedError("Cannot untransform the source/target split.")
+7 −0
Original line number Original line Diff line number Diff line
@@ -108,6 +108,13 @@ DAGTransformer
  :members:
  :members:
  :inherited-members:
  :inherited-members:


RxnSplitTransformer
^^^^^^^^^^^^^^^^^^^

.. autoclass:: deepchem.trans.RxnSplitTransformer
  :members:
  :inherited-members:

Base Transformer (for develop)
Base Transformer (for develop)
-------------------------------
-------------------------------