Commit 58fb0143 authored by leswing's avatar leswing
Browse files

CR Updates

parent 8d122fac
Loading
Loading
Loading
Loading
+3 −6
Original line number Diff line number Diff line
@@ -13,10 +13,7 @@ from deepchem.models.models import Model

class TensorGraph(Model):

  def __init__(self,
               tensorboard=False,
               learning_rate=0.001,
               **kwargs):
  def __init__(self, tensorboard=False, learning_rate=0.001, **kwargs):
    """
    TODO(LESWING) allow a model to change its learning rate
    TODO(LESWING) DOCUMENTATION AND TESTING
@@ -294,7 +291,6 @@ class TensorGraph(Model):
  def get_num_tasks(self):
    return len(self.labels)


  @staticmethod
  def load_from_dir(model_dir):
    pickle_name = os.path.join(model_dir, "model.pickle")
@@ -310,8 +306,9 @@ class MultiTaskTensorGraph(TensorGraph):
  classification metrics
  """

  def __init__(self, **kwargs):
  def __init__(self, mode='classification', **kwargs):
    self.task_weights = None
    self.mode = mode
    super().__init__(**kwargs)

  def set_task_weights(self, layer):
+1 −3
Original line number Diff line number Diff line
@@ -44,10 +44,8 @@ class TestTensorGraph(unittest.TestCase):
    g.set_loss(loss)
    g.add_output(dense)


    g.fit(dataset, nb_epoch=100)
    g.save()
    g1 = TensorGraph.load_from_dir('/tmp/tmpss5_ki5_')
    prediction = g1.predict_on_batch(X)
    assert (np.sum(prediction) > 9.9)
+2 −2
Original line number Diff line number Diff line
@@ -2,8 +2,8 @@
# Used to make a conda environment with deepchem

# Change commented out line For gpu tensorflow
export tensorflow=tensorflow-gpu
#export tensorflow=tensorflow
#export tensorflow=tensorflow-gpu
export tensorflow=tensorflow

if [ -z "$1" ]
then