Commit 93645bd3 authored by leswing's avatar leswing
Browse files

Hashable featurizers

parent 2817b137
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -94,3 +94,18 @@ class CircularFingerprint(Featurizer):
          useBondTypes=self.bonds,
          useFeatures=self.features)
    return fp

  def __hash__(self):
    return hash((self.radius, self.size, self.chiral, self.bonds, self.features,
                 self.sparse, self.smiles))

  def __eq__(self, other):
    if not isinstance(self, other.__class__):
      return False
    return self.radius == other.radius and \
           self.size == other.size and \
           self.chiral == other.chiral and \
           self.bonds == other.bonds and \
           self.features == other.features and \
           self.sparse == other.sparse and \
           self.smiles == other.smiles
+11 −0
Original line number Diff line number Diff line
@@ -358,6 +358,17 @@ class ConvMolFeaturizer(Featurizer):
  def feature_length(self):
    return 75 + len(self.atom_properties)

  def __hash__(self):
    atom_properties = tuple(self.atom_properties)
    return hash((self.master_atom, self.use_chirality, atom_properties))

  def __eq__(self, other):
    if not isinstance(self, other.__class__):
      return False
    return self.master_atom == other.master_atom and \
           self.use_chirality == other.use_chirality and \
           tuple(self.atom_properties) == tuple(other.atom_properties)


class WeaveFeaturizer(Featurizer):
  name = ['weave_mol']
+35 −2
Original line number Diff line number Diff line
"""
Test featurizer class.
"""
import numpy as np
import unittest

from deepchem.feat import ConvMolFeaturizer, CircularFingerprint
from deepchem.feat.basic import MolecularWeight
from rdkit import Chem

from deepchem.feat.basic import MolecularWeight

class TestFeaturizer(unittest.TestCase):
  """
  Tests for Featurizer.
  """

  def setUp(self):
    """
    Set up tests.
@@ -34,3 +35,35 @@ class TestFeaturizer(unittest.TestCase):
    f = MolecularWeight()
    rval = f([self.mol])
    assert rval.shape == (1, 1)

  def test_convmol_hashable(self):
    featurizer1 = ConvMolFeaturizer(atom_properties=['feature'])
    featurizer2 = ConvMolFeaturizer(atom_properties=['feature'])
    featurizer3 = ConvMolFeaturizer()

    d = set()
    d.add(featurizer1)
    d.add(featurizer2)
    d.add(featurizer3)

    self.assertEqual(2, len(d))
    featurizers = [featurizer1, featurizer2, featurizer3]

    for featurizer in featurizers:
      self.assertTrue(featurizer in featurizers)

  def test_circularfingerprint_hashable(self):
    featurizer1 = CircularFingerprint()
    featurizer2 = CircularFingerprint()
    featurizer3 = CircularFingerprint(size=5)

    d = set()
    d.add(featurizer1)
    d.add(featurizer2)
    d.add(featurizer3)

    self.assertEqual(2, len(d))
    featurizers = [featurizer1, featurizer2, featurizer3]

    for featurizer in featurizers:
      self.assertTrue(featurizer in featurizers)