Unverified Commit e3d7798e authored by Suzukazole's avatar Suzukazole
Browse files

add transform docstring

parent f46d8a0a
Loading
Loading
Loading
Loading
+25 −1
Original line number Diff line number Diff line
@@ -2479,6 +2479,30 @@ class RxnSplitTransformer(Transformer):
  def transform_array(
      self, X: np.ndarray, y: np.ndarray, w: np.ndarray,
      ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Transform the data in a set of (X, y, w) arrays.

    Parameters
    -----
    X: np.ndarray
      Array of features(the reactions)
    y: np.ndarray
      Array of labels
    w: np.ndarray
      Array of weights.
    ids: np.ndarray
      Array of weights.

    Returns
    -------
    Xtrans: np.ndarray
      Transformed array of features
    ytrans: np.ndarray
      Transformed array of labels
    wtrans: np.ndarray
      Transformed array of weights
    idstrans: np.ndarray
      Transformed array of ids
    """
    
    source = list(map(lambda x: x.split('>')[0] + '>' + x.split('>')[1], X))

@@ -2488,5 +2512,5 @@ class RxnSplitTransformer(Transformer):

    return (X, y, w, ids)

  def untransform(self, transformed):
  def untransform(self, transformed: np.ndarray) -> np.ndarray:
    return super().untransform(transformed)