Commit 0ad056ff authored by peastman's avatar peastman
Browse files

Created CGAN example

parent a03a6328
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from tensorflow.python.framework.errors_impl import OutOfRangeError
from deepchem.data import NumpyDataset
from deepchem.metrics import to_one_hot, from_one_hot
from deepchem.models.models import Model
from deepchem.models.tensorgraph.layers import InputFifoQueue, Label, Feature, Weights
from deepchem.models.tensorgraph.layers import InputFifoQueue, Label, Feature, Weights, Constant
from deepchem.models.tensorgraph.optimizers import Adam
from deepchem.trans import undo_transforms
from deepchem.utils.evaluate import GeneratorEvaluator
@@ -61,7 +61,7 @@ class TensorGraph(Model):
    self.outputs = list()
    self.task_weights = list()
    self.submodels = list()
    self.loss = None
    self.loss = Constant(0)
    self.built = False
    self.queue_installed = False
    self.optimizer = Adam(
@@ -190,10 +190,13 @@ class TensorGraph(Model):
      self.build()
    with self._get_tf("Graph").as_default():
      time1 = time.time()
      loss = self.loss
      if submodel is None:
        train_op = self._get_tf('train_op')
      else:
        train_op = submodel.get_train_op()
        if submodel.loss is not None:
          loss = submodel.loss
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      if restore:
@@ -222,14 +225,13 @@ class TensorGraph(Model):
        n_samples += 1
        should_log = (self.tensorboard and
                      n_samples % self.tensorboard_log_frequency == 0)
        fetches = [train_op, self.loss.out_tensor]
        fetches = [train_op, loss.out_tensor]
        if should_log:
          fetches.append(self._get_tf("summary_op"))
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
        if should_log:
          self._log_tensorboard(fetches[2])
        loss = fetched_values[1]
        avg_loss += loss
        avg_loss += fetched_values[1]
        n_averaged_batches += 1
        self.global_step += 1
        if checkpoint_interval > 0 and self.global_step % checkpoint_interval == checkpoint_interval - 1:
+308 −0

File added.

Preview size limit exceeded, changes collapsed.