Commit fe960807 authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #840 from rbharath/error_bar

Graphconv error bar fix for #838
parents 65d2b7c5 359f2db8
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -669,6 +669,26 @@ class GraphConvTensorGraph(TensorGraph):

    return mu[:max_index + 1], sigma[:max_index + 1]

  def bayesian_predict_on_batch(self, X, transformers=None, n_passes=4):
    """Summarise repeated prediction passes over X as a mean and a spread.

    Wraps `X` in a NumpyDataset, runs `n_passes` prediction passes over it
    and reduces the per-pass predictions element-wise.

    Parameters
    ----------
    X: array-like
      Input features; passed as `X` to NumpyDataset.
    transformers: list, optional
      Transformers forwarded to `predict_on_generator`. Defaults to an
      empty list.
    n_passes: int, optional
      Number of prediction passes to run. Must be at least 1.

    Returns
    -------
    mu: numpy ndarray of shape (n_samples, n_tasks)
      Mean of the predictions across passes.
    sigma: numpy ndarray of shape (n_samples, n_tasks)
      Standard deviation of the predictions across passes.

    Raises
    ------
    ValueError
      If `n_passes` is smaller than 1.
    """
    # With zero passes the reductions below would run over an empty array
    # and silently yield NaN scalars, so reject that up front.
    if n_passes < 1:
      raise ValueError("n_passes must be at least 1, got %s" % (n_passes,))
    # None sentinel instead of a shared mutable default list.
    if transformers is None:
      transformers = []

    dataset = NumpyDataset(X=X, y=None, n_tasks=len(self.outputs))
    passes = []
    for _ in range(n_passes):
      # A new generator is created for every pass.
      generator = self.default_generator(
          dataset, predict=True, pad_batches=True)
      passes.append(self.predict_on_generator(generator, transformers))

    # Stack the passes along a new leading axis, then reduce over that axis.
    stacked = np.array(passes)
    mu = np.mean(stacked, axis=0)
    sigma = np.std(stacked, axis=0)

    return mu, sigma

  def predict_on_smiles(self, smiles, transformers=[], untransform=False):
    """Generates predictions on a numpy array of smile strings

+15 −0
Original line number Diff line number Diff line
@@ -55,3 +55,18 @@ def test_graph_conv_regression_model():
  model.save()
  model = TensorGraph.load_from_dir(model.model_dir)
  scores = model.evaluate(dataset, [metric], transformers)


def test_graph_conv_error_bars():
  """bayesian_predict must return mu and sigma of shape (n_samples, n_tasks)."""
  tasks, dataset, transformers, _ = get_dataset('regression', 'GraphConv')
  n_tasks = len(tasks)

  # Train for a single epoch; only the output shapes are checked below.
  model = GraphConvTensorGraph(n_tasks, batch_size=50, mode='regression')
  model.fit(dataset, nb_epoch=1)

  mu, sigma = model.bayesian_predict(
      dataset, transformers, untransform=True, n_passes=24)

  expected_shape = (len(dataset), n_tasks)
  assert mu.shape == expected_shape
  assert sigma.shape == expected_shape
+1 −1
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ model.save()
model.load_from_dir('model_saves')

mu, sigma = model.bayesian_predict(
    valid_dataset.X, transformers, untransform=True, n_passes=24)
    valid_dataset, transformers, untransform=True, n_passes=24)
print(mu[:4])
print(sigma[:4])