Commit 8cbcb5b7 authored by Peter Eastman's avatar Peter Eastman
Browse files

Still trying to make yapf happy

parent e9c700ea
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -294,7 +294,6 @@ class TensorflowGraphModel(Model):

      return loss


  def fit(self,
          dataset,
          nb_epoch=10,
@@ -340,7 +339,8 @@ class TensorflowGraphModel(Model):
        def enqueue(sess, dataset, nb_epoch, epoch_end_indices):
          index = 0
          for epoch in range(nb_epoch):
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches):
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(
                self.batch_size, pad_batches=self.pad_batches):
              feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
              sess.run(self.train_graph.graph.enqueue, feed_dict=feed_dict)
              index += 1
@@ -348,7 +348,8 @@ class TensorflowGraphModel(Model):
          sess.run(self.train_graph.graph.queue.close())

        epoch_end_indices = []
        enqueue_thread = threading.Thread(target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread = threading.Thread(
            target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread.daemon = True
        enqueue_thread.start()

+24 −12
Original line number Diff line number Diff line
@@ -52,10 +52,15 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
      assert n_layers > 0, 'Must have some layers defined.'

      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(capacity=5, dtypes=[tf.float32]*(len(label_placeholders)+len(weight_placeholders)+1))
        graph.enqueue = graph.queue.enqueue([mol_features]+label_placeholders+weight_placeholders)
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
@@ -144,10 +149,15 @@ class TensorflowMultiTaskRegressor(TensorflowRegressor):
      assert n_layers > 0, 'Must have some layers defined.'

      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(capacity=5, dtypes=[tf.float32]*(len(label_placeholders)+len(weight_placeholders)+1))
        graph.enqueue = graph.queue.enqueue([mol_features]+label_placeholders+weight_placeholders)
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
@@ -375,7 +385,8 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
        def enqueue(sess, dataset, nb_epoch, epoch_end_indices):
          index = 0
          for epoch in range(nb_epoch):
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches):
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(
                self.batch_size, pad_batches=self.pad_batches):
              for transformer in self.fit_transformers:
                X_b = transformer.X_transform(X_b)
              feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
@@ -385,7 +396,8 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
          sess.run(self.train_graph.graph.queue.close())

        epoch_end_indices = []
        enqueue_thread = threading.Thread(target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread = threading.Thread(
            target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread.daemon = True
        enqueue_thread.start()