Unverified Commit 7da392a6 authored by Suzukazole's avatar Suzukazole
Browse files

update docs

parent 41c8989f
Loading
Loading
Loading
Loading
+15 −17
Original line number Diff line number Diff line
@@ -2487,7 +2487,16 @@ class RxnSplitTransformer(Transformer):

  """

  def __init__(self, sep_reagent: bool, dataset: Optional[Dataset] = None):
  def __init__(self, sep_reagent: bool = True, dataset: Optional[Dataset] = None):
    """Initializes the Reaction split Transformer.

    Parameters
    ----------
    sep_reagent: bool, optional (default True)
      To separate the reagent and reactants for training.
    dataset: dc.data.Dataset object, optional (default None)
      Dataset to be transformed. 
    """
    # the transformer would have to split the source and target sequences
    # would also consider adding the option of separating the reagent here.

@@ -2497,10 +2506,10 @@ class RxnSplitTransformer(Transformer):
  def transform_array(
      self, X: np.ndarray, y: np.ndarray, w: np.ndarray,
      ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Transform the data in a set of (X, y, w) arrays.
    """Transform the data in a set of (X, y, w, ids) arrays.

    Parameters
    -----
    ----------
    X: np.ndarray
      Array of features(the reactions)
    y: np.ndarray
@@ -2540,17 +2549,6 @@ class RxnSplitTransformer(Transformer):

    return (X, y, w, ids)

  def untransform(self, z: np.ndarray) -> np.ndarray:
    """
    Undo transformation on provided data.
    Parameters
    ----------
    z: np.ndarray
      Array to transform back
    Returns
    -------
    np.ndarray
      Array with normalization undone.
    """
    # TODO
    return super().untransform(z)
  def untransform(self, z):
    """Not Implemented."""
    raise NotImplementedError("Cannot untransform the source/target split.")