Commit ff82a198 authored by leswing's avatar leswing
Browse files

Cover the Squeeze layer

parent 916212c9
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from tensorflow.python.framework import test_util
from deepchem.feat.mol_graphs import ConvMol
from deepchem.feat.mol_graphs import MultiConvMol
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.models.tensorgraph.layers import Conv1D
from deepchem.models.tensorgraph.layers import Conv1D, Squeeze
from deepchem.models.tensorgraph.layers import Dense
from deepchem.models.tensorgraph.layers import Flatten
from deepchem.models.tensorgraph.layers import Reshape
@@ -521,3 +521,11 @@ class TestLayers(test_util.TensorFlowTestCase):
      result = out_tensor.eval()
      assert result.shape == (1, 6, 1)
      assert np.array_equal(value1.reshape((1, 6, 1)) + value2, result)

  def test_squeeze_inputs(self):
    """Test that layers can automatically reshape inconsistent inputs."""
    value1 = np.random.uniform(size=(2, 1)).astype(np.float32)
    with self.test_session() as sess:
      out_tensor = Squeeze(squeeze_dims=1)(tf.constant(value1))
      result = out_tensor.eval()
      assert result.shape == (2,)