Unverified Commit 2a72d8e1 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1452 from peastman/pad

pad_batch() correctly handles labels with more than two dimensions
parents 9219d4ed dc222769
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -99,14 +99,14 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
  elif len(y_b.shape) < 2:
    y_out = np.zeros(batch_size, dtype=y_b.dtype)
  else:
    y_out = np.zeros((batch_size, y_b.shape[1]), dtype=y_b.dtype)
    y_out = np.zeros((batch_size,) + y_b.shape[1:], dtype=y_b.dtype)

  if w_b is None:
    w_out = None
  elif len(w_b.shape) < 2:
    w_out = np.zeros(batch_size, dtype=w_b.dtype)
  else:
    w_out = np.zeros((batch_size, w_b.shape[1]), dtype=w_b.dtype)
    w_out = np.zeros((batch_size,) + w_b.shape[1:], dtype=w_b.dtype)

  ids_out = np.zeros((batch_size,), dtype=ids_b.dtype)