Commit 1195101e authored by Joseph Gomes's avatar Joseph Gomes
Browse files

Add CoulombFitTransformer

parent 68d6faa1
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -15,3 +15,4 @@ from deepchem.trans.transformers import CDFTransformer
from deepchem.trans.transformers import PowerTransformer
from deepchem.trans.transformers import CoulombRandomizationFitTransformer
from deepchem.trans.transformers import NormalizationFitTransformer
from deepchem.trans.transformers import CoulombFitTransformer
+47 −0
Original line number Diff line number Diff line
@@ -696,3 +696,50 @@ class NormalizationFitTransformer():
  def untransform(self, z):
    raise NotImplementedError(
      "Cannot untransform datasets with FitTransformer.")

class CoulombFitTransformer():

  def __init__(self, X, num_atoms=23):
    self.step = 1.0
    self.noise = 1.0
    self.triuind = (np.arange(num_atoms)[:,np.newaxis] <= np.arange(num_atoms)[np.newaxis,:]).flatten()
    self.max = 0
    for _ in range(10): self.max = np.maximum(self.max,self.realize(X).max(axis=0))
    X = self.expand(self.realize(X))
    self.nbout = X.shape[1]
    self.mean = X.mean(axis=0)
    self.std = (X - self.mean).std()

  def realize(self, X):
    """Randomize features. """
    def _realize_(x):
      inds = np.argsort(-(x**2).sum(axis=0)**.5+np.random.normal(0,self.noise,x[0].shape))
      x = x[inds,:][:,inds]*1
      x = x.flatten()[self.triuind]
      return x
    return np.array([_realize_(z) for z in X])

  def normalize(self, X):
    """Normalize features. """
    return (X-self.mean)/self.std

  def expand(self, X):
    """Binarize features. """
    Xexp = []
    for i in range(X.shape[1]):
      for k in np.arange(0,self.max[i]+self.step,self.step):
        Xexp += [np.tanh((X[:,i]-k)/self.step)]
    return np.array(Xexp).T
      
  def X_transform(self, X):
    X = self.normalize(self.expand(self.realize(X)))
    return X

  def transform(self, dataset):
    raise NotImplementedError(
      "Cannot transform datasets with FitTransformer")

  def untransform(self, z):
    raise NotImplementedError(
      "Cannot untransform datasets with FitTransformer.")