Commit 4ec24adf authored by miaecle's avatar miaecle
Browse files

IRV handle large dataset

parent 18e74be0
Loading
Loading
Loading
Loading
+25 −2
Original line number Diff line number Diff line
@@ -682,12 +682,35 @@ class IRVTransformer():
    """
    X_target2 = []
    n_features = X_target.shape[1]
    similarity = np.matmul(X_target, np.transpose(self.X)) / (
        n_features - np.matmul(1 - X_target, np.transpose(1 - self.X)))
    similarity = matrix_mul(X_target, np.transpose(self.X)) / (
        n_features - matrix_mul(1 - X_target, np.transpose(1 - self.X)))
    for i in range(self.n_tasks):
      X_target2.append(self.realize(similarity, self.y[:, i], self.w[:, i]))
    return np.concatenate([z for z in np.array(X_target2)], axis=1)

  @staticmethod
  def matrix_mul(X1, X2, shard_size=1000):
    X1_shape = X1.shape
    X2_shape = X2.shape
    assert X1_shape[1] == X2_shape[0]
    X1_iter = X1_shape[0]//shard_size + 1
    X2_iter = X2_shape[1]//shard_size + 1
    all_result = np.zeros((1,))
    for X1_id in range(X1_iter):
      result = np.zeros((1,))
      for X2_id in range(X2_iter):
        partial_result = np.matmul(X1[X1_id*shard_size:min((X1_id+1)*shard_size, X1_shape[0]),:],
                                   X2[:, X2_id*shard_size:min((X2_id+1)*shard_size, X2_shape[1])])
        if result.size == 1:
          result = partial_result
        else:
          result = np.concatenate((result, partial_result), axis=1)
      if all_result.size == 1:
        all_result = result
      else:
        all_result = np.concatenate((all_result, result), axis=0)
    return all_result    

  def transform(self, dataset):
    """Build a new dataset whose features are the transformed ``dataset.X``.

    ``dataset.X`` is passed through ``self.X_transform``; ``dataset.y`` and
    ``dataset.w`` are forwarded unchanged, and the resulting ``NumpyDataset``
    is created with ``ids=None``.
    """
    transformed_features = self.X_transform(dataset.X)
    labels, weights = dataset.y, dataset.w
    return NumpyDataset(transformed_features, labels, weights, ids=None)
+3 −2
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import numpy as np
import shutil
import deepchem as dc

def load_sampl(featurizer='ECFP', split='index'):
def load_sampl(featurizer='ECFP', split='random', frac_train=0.8):
  """Load SAMPL datasets."""
  # Featurize SAMPL dataset
  print("About to featurize SAMPL dataset.")
@@ -39,5 +39,6 @@ def load_sampl(featurizer='ECFP', split='index'):
               'random': dc.splits.RandomSplitter(),
               'scaffold': dc.splits.ScaffoldSplitter()}
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)
  train, valid, test = splitter.train_valid_test_split(dataset, frac_train=frac_train,
                             frac_valid=1-frac_train, frac_test=0.)
  return SAMPL_tasks, (train, valid, test), transformers