Unverified Commit 53aa9724 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1286 from peastman/reshape

Improved logic for reshaping inputs to TensorGraph
parents 9b13b31a 059d2752
Loading
Loading
Loading
Loading
+30 −38
Original line number Diff line number Diff line
@@ -1062,36 +1062,22 @@ class TensorGraph(Model):
    training: bool
      True during training, False during prediction
    """
    if tfe.in_eager_mode():
    train_value = 1.0 if training else 0.0
    if self.queue_installed:
      while True:
        yield {self._training_placeholder: train_value}
    else:
      for d in generator:
        feed_dict = {}
        for key, value in d.items():
          if isinstance(key, Input):
            # Add or remove dimensions of size 1 to match the shape of the layer.
            try:
              value_dims = len(value.shape)
              layer_dims = len(key.shape)
              if value_dims < layer_dims:
                if all(i == 1 for i in key.shape[value_dims:]):
                  value = tf.reshape(value,
                                     list(value.shape) + [1] *
                                     (layer_dims - value_dims))
              if value_dims > layer_dims:
                if all(i == 1 for i in value.shape[layer_dims:]):
                  value = tf.reshape(value, value.shape[:layer_dims])
            except:
              pass
            feed_dict[key] = tf.cast(value, key.dtype)
          else:
            value = _ensure_value_shape(value, key)
            if tfe.in_eager_mode():
              value = tf.cast(value, key.dtype)
            feed_dict[key] = value
        yield feed_dict
          else:
      train_value = 1.0 if training else 0.0
      if self.queue_installed:
        while True:
          yield {self._training_placeholder: train_value}
      for d in generator:
        feed_dict = dict(d)
            feed_dict[key] = value
        if not tfe.in_eager_mode():
          feed_dict[self._training_placeholder] = train_value
        yield feed_dict

@@ -1272,6 +1258,24 @@ class TensorGraph(Model):
    return tensors


def _ensure_value_shape(value, layer):
  """Ensure that a value has the right shape for an input layer."""
  # Add or remove dimensions of size 1 to match the shape of the layer.
  try:
    value_dims = len(value.shape)
    layer_dims = len(layer.shape)
    if value_dims < layer_dims:
      if all(i == 1 for i in layer.shape[value_dims:]):
        value = value.reshape(
            list(value.shape) + [1] * (layer_dims - value_dims))
    if value_dims > layer_dims:
      if all(i == 1 for i in value.shape[layer_dims:]):
        value = value.reshape(value.shape[:layer_dims])
  except:
    pass
  return value


def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
  """
  Function to load data into
@@ -1294,19 +1298,7 @@ def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
      for layer in tg.features + tg.labels + tg.task_weights:
        if layer in feed_dict:
          value = feed_dict[layer]
          # Add or remove dimensions of size 1 to match the shape of the layer.
          try:
            value_dims = len(value.shape)
            layer_dims = len(layer.shape)
            if value_dims < layer_dims:
              if all(i == 1 for i in layer.shape[value_dims:]):
                value = value.reshape(
                    list(value.shape) + [1] * (layer_dims - value_dims))
            if value_dims > layer_dims:
              if all(i == 1 for i in value.shape[layer_dims:]):
                value = value.reshape(value.shape[:layer_dims])
          except:
            pass
          value = _ensure_value_shape(value, layer)
        else:
          value = np.zeros(
              [0] + list(layer.shape[1:]), dtype=layer.dtype.as_numpy_dtype)