Commit 8a1ad36d authored by peastman's avatar peastman
Browse files

Fixed calculation of convolution output shape

parent d5dad129
Loading
Loading
Loading
Loading
+31 −21
Original line number Diff line number Diff line
@@ -446,6 +446,16 @@ class SharedVariableScope(Layer):
      return self._shared_with._get_scope_name()


def _conv_size(width, size, stride, padding):
  """Compute the output size of a convolutional layer."""
  if padding.lower() == 'valid':
    return 1 + (width - size) // stride
  elif padding.lower() == 'same':
    return 1 + (width - 1) // stride
  else:
    raise ValueError('Unknown padding type: %s' % padding)


class Conv1D(Layer):
  """A 1D convolution on the input.

@@ -544,11 +554,11 @@ class Conv1D(Layer):
      parent_shape = self.in_layers[0].shape
      if isinstance(strides, int):
        strides = (strides,)
      if padding.lower() == 'same':
      if isinstance(kernel_size, int):
        kernel_size = (kernel_size,)
      self._shape = (parent_shape[0],
                       int(np.ceil(parent_shape[1] / strides[0])), filters)
      else:
        self._shape = (parent_shape[0], parent_shape[1] // strides[0], filters)
                     _conv_size(parent_shape[1], kernel_size[0], strides[0],
                                padding), filters)
    except:
      pass

@@ -1993,13 +2003,13 @@ class Conv2D(SharedVariableScope):
      strides = stride
      if isinstance(stride, int):
        strides = (stride, stride)
      if padding.lower() == 'same':
      if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
      self._shape = (parent_shape[0],
                       int(np.ceil(parent_shape[1] / strides[0])),
                       int(np.ceil(parent_shape[2] / strides[1])), num_outputs)
      else:
        self._shape = (parent_shape[0], parent_shape[1] // strides[0],
                       parent_shape[2] // strides[1], num_outputs)
                     _conv_size(parent_shape[1], kernel_size[0], strides[0],
                                padding),
                     _conv_size(parent_shape[2], kernel_size[1], strides[1],
                                padding), num_outputs)
    except:
      pass

@@ -2115,15 +2125,15 @@ class Conv3D(SharedVariableScope):
      strides = stride
      if isinstance(stride, int):
        strides = (stride, stride, stride)
      if padding.lower() == 'same':
      if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
      self._shape = (parent_shape[0],
                       int(np.ceil(parent_shape[1] / strides[0])),
                       int(np.ceil(parent_shape[2] / strides[1])),
                       int(np.ceil(parent_shape[3] / strides[2])), num_outputs)
      else:
        self._shape = (parent_shape[0], parent_shape[1] // strides[0],
                       parent_shape[2] // strides[1],
                       parent_shape[3] // strides[2], num_outputs)
                     _conv_size(parent_shape[1], kernel_size[0], strides[0],
                                padding),
                     _conv_size(parent_shape[2], kernel_size[1], strides[1],
                                padding),
                     _conv_size(parent_shape[3], kernel_size[2], strides[2],
                                padding), num_outputs)
    except:
      pass