Commit 2c74d385 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fix to padding

parent 372d514e
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -180,7 +180,7 @@ class TensorflowGraph(object):
        self.updates = tf.no_op(name='updates')

  def fit(self, dataset, shuffle=False, max_checkpoints_to_keep=5,
          log_every_N_batches=50, pad_batches=False):
          log_every_N_batches=50):
    """Fit the model.

    Args:
@@ -199,6 +199,10 @@ class TensorflowGraph(object):
    step_per_epoch = np.ceil(float(num_datapoints)/batch_size)
    nb_epoch = self.model_params["nb_epoch"]
    log("Training for %d epochs" % nb_epoch, self.verbosity)
    if "pad_batches" in self.model_params:
      pad_batches = self.model_params["pad_batches"]
    else:
      pad_batches = False
    with self.graph.as_default():
      self.require_attributes(['loss', 'updates'])
      train_op = self.get_training_op()
@@ -575,11 +579,11 @@ class TensorflowModel(Model):
    self.num_tasks = len(self.task_types)
    self.fit_transformers = None

  def fit(self, dataset, shuffle=False, pad_batches=False):
  def fit(self, dataset, shuffle=False):
    """
    Fits TensorflowGraph to data.
    """
    self.train_model.fit(dataset, shuffle=shuffle, pad_batches=pad_batches)
    self.train_model.fit(dataset, shuffle=shuffle)

  def predict_on_batch(self, X):
    """