Commit 224352b8 authored by miaecle's avatar miaecle
Browse files

Merge remote-tracking branch 'remotes/mine/IRV'

parents 719a506a ef2b0b0e
Loading
Loading
Loading
Loading
+64 −0
Original line number Diff line number Diff line
@@ -568,3 +568,67 @@ class CoulombFitTransformer():
    raise NotImplementedError(
      "Cannot untransform datasets with FitTransformer.")

 class IRVTransformer():
  """Performs randomization and binarization operations on batches of Coulomb Matrix features during fit."""
  def __init__(self, K, dataset):

    """Initializes CoulombFitTransformer.
    Parameters:
    ----------
    dataset: dc.data.Dataset object

    K: number of nearest neighbours that count
    
    """
    self.X = dataset.X
    self.n_samples = X.shape[0]
    self.K = K
    self.y = dataset.y
    self.w = dataset.w

  @staticmethod
  def similarity(x, y):
    if x.shape != y.shape:
      raise ValueError(
      "Similarity measurement requires same size")
    return np.sum(np.min([x,y], axis=0))/np.sum(np.max([x,y], axis=0))

  def realize(self, X_target):
    """Randomize features. """
    def _realize_(x_target):
      similar_xs = []
      similar_inds = []
      threshold = 0
      for count, x in enumerate(self.X):
        similar = IRVTransformer.similarity(x, x_target)
        if similar >= 1:
          continue
        if len(similar_xs) < self.K-1:
          similar_xs.append(similar)
          similar_inds.append(count)
        if len(similar_xs) == self.K-1:
          if self.w[count]
          similar_xs.append(similar)
          similar_inds.append(count)
          threshold = min(similar_xs)
        elif similar > threshold:
          ind = similar_xs.index(threshold)
          similar_xs.pop(ind)
          similar_inds.pop(ind)
          similar_xs.append(similar)
          similar_inds.append(count)
          threshold = min(similar_xs)
      return similar_xs.extend(similar_inds)
    return np.array([_realize_(z) for z in X_target])
      
  def X_transform(self, X_target):
    X_target = self.realize(X_target)
    return X_target

  def transform(self, dataset):
    raise NotImplementedError(
      "Cannot transform datasets with FitTransformer")

  def untransform(self, z):
    raise NotImplementedError(
      "Cannot untransform datasets with FitTransformer.")