Unverified Commit 5b20c8f9 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1005 from lilleswing/feat_trans_master

Featurization Transformer
parents d0fac17d 0dec77d2
Loading
Loading
Loading
Loading
+21 −6
Original line number Diff line number Diff line
"""
Tests for transformer objects. 
"""
from __future__ import print_function
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from deepchem.molnet import load_delaney
from deepchem.trans.transformers import FeaturizationTransformer

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"
@@ -14,7 +17,6 @@ import unittest
import numpy as np
import pandas as pd
import deepchem as dc
import numpy.random as random


class TestTransformers(unittest.TestCase):
@@ -468,3 +470,16 @@ class TestTransformers(unittest.TestCase):
    assert np.allclose(test_dataset_trans.X[0, :10], sims[:10])
    assert np.allclose(test_dataset_trans.X[0, 10:20], [0] * 10)
    assert not np.isclose(dataset_trans.X[0, 0], 1.)

  def test_featurization_transformer(self):
    fp_size = 2048
    tasks, all_dataset, transformers = load_delaney('Raw')
    train = all_dataset[0]
    transformer = FeaturizationTransformer(
        transform_X=True,
        dataset=train,
        featurizer=dc.feat.CircularFingerprint(size=fp_size))
    new_train = transformer.transform(train)

    self.assertEqual(new_train.y.shape, train.y.shape)
    self.assertEqual(new_train.X.shape[-1], fp_size)
+31 −4
Original line number Diff line number Diff line
@@ -1141,3 +1141,30 @@ class ANITransformer(Transformer):
  def get_num_feats(self):
    n_feat = self.outputs.get_shape().as_list()[-1]
    return n_feat


class FeaturizationTransformer(Transformer):
  """
  A transformer which runs a featurizer over the X values of a dataset.
  Datasets used by this transformer must have rdkit.mol objects as the X
  values
  """

  def __init__(self,
               transform_X=False,
               transform_y=False,
               transform_w=False,
               dataset=None,
               featurizer=None):
    self.featurizer = featurizer
    if not transform_X:
      raise ValueError("FeaturizingTransfomer can only be used on X")
    super(FeaturizationTransformer, self).__init__(
        transform_X=transform_X,
        transform_y=transform_y,
        transform_w=transform_w,
        dataset=dataset)

  def transform_array(self, X, y, w):
    X = self.featurizer.featurize(X)
    return X, y, w