Commit bac742c5 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes'

parent 1db83d19
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -504,7 +504,7 @@ class DAGModel(KerasModel):
        ], [y_b], [w_b])


class GraphConvKerasModel(tf.keras.Model):
class _GraphConvKerasModel(tf.keras.Model):

  def __init__(self,
               n_tasks,
@@ -525,7 +525,7 @@ class GraphConvKerasModel(tf.keras.Model):

    All arguments have the same meaning as in GraphConvModel.
    """
    super(GraphConvKerasModel, self).__init__()
    super(_GraphConvKerasModel, self).__init__()
    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")

@@ -628,7 +628,7 @@ class GraphConvModel(KerasModel):
               **kwargs):
    """The wrapper class for graph convolutions.

    Note that since the underlying GraphConvKerasModel class is
    Note that since the underlying _GraphConvKerasModel class is
    specified using imperative subclassing style, this model
    cannout make predictions for arbitrary outputs. 

@@ -662,7 +662,7 @@ class GraphConvModel(KerasModel):
    self.n_classes = n_classes
    self.batch_size = batch_size
    self.uncertainty = uncertainty
    model = GraphConvKerasModel(
    model = _GraphConvKerasModel(
        n_tasks,
        graph_conv_layers=graph_conv_layers,
        dense_layer_size=dense_layer_size,
+3 −3
Original line number Diff line number Diff line
@@ -50,7 +50,7 @@ class TestGraphModels(unittest.TestCase):
    model = GraphConvModel(
        len(tasks), batch_size=batch_size, mode='classification')

    model.fit(dataset, nb_epoch=10)
    model.fit(dataset, nb_epoch=50)
    scores = model.evaluate(dataset, [metric], transformers)
    assert scores['mean-roc_auc_score'] >= 0.9

@@ -79,7 +79,7 @@ class TestGraphModels(unittest.TestCase):
    batch_size = 50
    model = GraphConvModel(len(tasks), batch_size=batch_size, mode='regression')

    model.fit(dataset, nb_epoch=100)
    model.fit(dataset, nb_epoch=800)
    scores = model.evaluate(dataset, [metric], transformers)
    assert all(s < 0.1 for s in scores['mean_absolute_error'])

@@ -95,7 +95,7 @@ class TestGraphModels(unittest.TestCase):
        dropout=0.1,
        uncertainty=True)

    model.fit(dataset, nb_epoch=100)
    model.fit(dataset, nb_epoch=500)

    # Predict the output and uncertainty.
    pred, std = model.predict_uncertainty(dataset)