Commit 7432dcca authored by peastman
Browse files

IRVTransformer extends Transformer

parent 285c9081
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -1253,8 +1253,8 @@ class CoulombFitTransformer(Transformer):
        "Cannot untransform datasets with FitTransformer.")


class IRVTransformer():
  """Performs transform from ECFP to IRV features(K nearest neibours)."""
class IRVTransformer(Transformer):
  """Performs transform from ECFP to IRV features(K nearest neighbors)."""

  def __init__(self, K, n_tasks, dataset, transform_y=False, transform_x=False):
    """Initializes IRVTransformer.
@@ -1273,8 +1273,7 @@ class IRVTransformer():
    self.K = K
    self.y = dataset.y
    self.w = dataset.w
    super(IRVTransformer, self).__init__(
        transform_X=transform_x, transform_y=transform_y)
    super(IRVTransformer, self).__init__(transform_X=True)

  def realize(self, similarity, y, w):
    """find samples with top ten similarity values in the reference dataset
@@ -1395,7 +1394,7 @@ class IRVTransformer():
      del result
    return all_result

  def transform(self, dataset):
  def transform(self, dataset, parallel=False, out_dir=None, **kwargs):
    X_length = dataset.X.shape[0]
    X_trans = []
    for count in range(X_length // 5000 + 1):
@@ -1403,7 +1402,10 @@ class IRVTransformer():
          self.X_transform(
              dataset.X[count * 5000:min((count + 1) * 5000, X_length), :]))
    X_trans = np.concatenate(X_trans, axis=0)
    if out_dir is None:
      return NumpyDataset(X_trans, dataset.y, dataset.w, ids=None)
    return DiskDataset.from_numpy(
        X_trans, dataset.y, dataset.w, data_dir=out_dir)

  def untransform(self, z):
    raise NotImplementedError(