Commit 4db6f6f5 authored by Vignesh's avatar Vignesh
Browse files

Removed type casting from DTNNEmbedding; moved it to create_estimators

parent f1415632
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -338,7 +338,6 @@ class DTNNEmbedding(Layer):
    self.build()

    atom_number = in_layers[0].out_tensor
    atom_number = tf.cast(atom_number, dtype=tf.int32)
    atom_features = tf.nn.embedding_lookup(self.embedding_list, atom_number)
    out_tensor = atom_features
    if set_tensors:
+6 −1
Original line number Diff line number Diff line
@@ -239,7 +239,12 @@ class TextCNNModel(TensorGraph):
    """Creates tensors for inputs."""
    tensors = dict()
    for layer, column in zip(self.features, feature_columns):
      tensors[layer] = tf.feature_column.input_layer(features, [column])
      feature_col = tf.feature_column.input_layer(features, [column])
      if column.dtype != feature_col.dtype:
        feature_col = tf.cast(feature_col, column.dtype)
      if len(column.shape) < 1:
        feature_col = tf.reshape(feature_col, shape=[tf.shape(feature_col)[0]])
      tensors[layer] = feature_col
    if weight_column is not None:
      tensors[self.task_weights[0]] = tf.feature_column.input_layer(
          features, [weight_column])