Unverified Commit 1b171da6 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1267 from lilleswing/hashable-featurizers

Hashable Featurizers
parents efa6e259 93645bd3
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -94,3 +94,18 @@ class CircularFingerprint(Featurizer):
          useBondTypes=self.bonds,
          useFeatures=self.features)
    return fp

  def __hash__(self):
    """Hash the featurizer by its configuration attributes.

    The value is the hash of the tuple (radius, size, chiral, bonds,
    features, sparse, smiles), in that order.
    """
    config = (
        self.radius,
        self.size,
        self.chiral,
        self.bonds,
        self.features,
        self.sparse,
        self.smiles,
    )
    return hash(config)

  def __eq__(self, other):
    """Compare two featurizers by their configuration attributes.

    Parameters
    ----------
    other: object
      The object to compare against.

    Returns
    -------
    bool
      True when `other` is an instance of this featurizer's class and
      radius, size, chiral, bonds, features, sparse and smiles are all
      equal; False otherwise.
    """
    # The type check has to be made on `other`. Checking `self` against
    # `other.__class__` passes for any base class (every object is an
    # instance of `object`), after which the attribute reads below raise
    # AttributeError instead of reporting inequality.
    if not isinstance(other, self.__class__):
      return False
    compared = ('radius', 'size', 'chiral', 'bonds', 'features', 'sparse',
                'smiles')
    return all(getattr(self, name) == getattr(other, name) for name in compared)
+11 −0
Original line number Diff line number Diff line
@@ -358,6 +358,17 @@ class ConvMolFeaturizer(Featurizer):
  def feature_length(self):
    """Return 75 plus the number of entries in `self.atom_properties`."""
    extra_properties = len(self.atom_properties)
    return 75 + extra_properties

  def __hash__(self):
    """Hash the featurizer by master_atom, use_chirality and atom_properties.

    `atom_properties` is copied into a tuple before hashing (presumably
    because it may be held as an unhashable list).
    """
    key = (self.master_atom, self.use_chirality, tuple(self.atom_properties))
    return hash(key)

  def __eq__(self, other):
    """Compare two featurizers by their configuration attributes.

    Parameters
    ----------
    other: object
      The object to compare against.

    Returns
    -------
    bool
      True when `other` is an instance of this featurizer's class and
      master_atom, use_chirality and the atom_properties sequence are all
      equal; False otherwise.
    """
    # The type check has to be made on `other`. Checking `self` against
    # `other.__class__` passes for any base class (every object is an
    # instance of `object`), after which the attribute reads below raise
    # AttributeError instead of reporting inequality.
    if not isinstance(other, self.__class__):
      return False
    # atom_properties is compared as a tuple so the container type does not
    # affect the result, mirroring the tuple conversion used for hashing.
    return (self.master_atom == other.master_atom and
            self.use_chirality == other.use_chirality and
            tuple(self.atom_properties) == tuple(other.atom_properties))


class WeaveFeaturizer(Featurizer):
  name = ['weave_mol']
+35 −2
Original line number Diff line number Diff line
"""
Test featurizer class.
"""
import numpy as np
import unittest

from deepchem.feat import ConvMolFeaturizer, CircularFingerprint
from deepchem.feat.basic import MolecularWeight
from rdkit import Chem

from deepchem.feat.basic import MolecularWeight

class TestFeaturizer(unittest.TestCase):
  """
  Tests for Featurizer.
  """

  def setUp(self):
    """
    Set up tests.
@@ -34,3 +35,35 @@ class TestFeaturizer(unittest.TestCase):
    f = MolecularWeight()
    rval = f([self.mol])
    assert rval.shape == (1, 1)

  def test_convmol_hashable(self):
    """ConvMolFeaturizer instances can be stored in a set and found in a list."""
    with_property_a = ConvMolFeaturizer(atom_properties=['feature'])
    with_property_b = ConvMolFeaturizer(atom_properties=['feature'])
    default_config = ConvMolFeaturizer()
    featurizers = [with_property_a, with_property_b, default_config]

    # The two identically configured featurizers must collapse into a single
    # set entry, leaving two distinct members.
    unique = set(featurizers)
    self.assertEqual(2, len(unique))

    for candidate in featurizers:
      self.assertTrue(candidate in featurizers)

  def test_circularfingerprint_hashable(self):
    """CircularFingerprint instances can be stored in a set and found in a list."""
    default_a = CircularFingerprint()
    default_b = CircularFingerprint()
    resized = CircularFingerprint(size=5)
    featurizers = [default_a, default_b, resized]

    # The two default-configured featurizers must collapse into a single set
    # entry, leaving two distinct members.
    unique = set(featurizers)
    self.assertEqual(2, len(unique))

    for candidate in featurizers:
      self.assertTrue(candidate in featurizers)