Commit 31217396 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #531 from miaecle/DTNN

auPRC metric and doc strings for graph models
parents 4050c527 097727e4
Loading
Loading
Loading
Loading
+12 −2
Original line number Original line Diff line number Diff line
@@ -11,6 +11,8 @@ from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import precision_score
from sklearn.metrics import precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from scipy.stats import pearsonr
from scipy.stats import pearsonr




@@ -70,6 +72,14 @@ def pearson_r2_score(y, y_pred):
  return pearsonr(y, y_pred)[0]**2
  return pearsonr(y, y_pred)[0]**2




def prc_auc_score(y, y_pred):
  """Compute the area under the precision-recall curve.

  Parameters
  ----------
  y: array-like with a `.shape` attribute, two columns
    True labels. Only column 1 is used as the ground truth
    (presumably a one-hot encoding of binary labels -- confirm with caller).
  y_pred: array-like, same shape as `y`
    Predicted scores. Only column 1 is used as the score.

  Returns
  -------
  The value of `auc(recall, precision)` for the curve built from
  column 1 of `y` and `y_pred`.
  """
  # Inputs must line up element-wise and carry exactly two columns.
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2
  positive_labels = y[:, 1]
  positive_scores = y_pred[:, 1]
  precision, recall, _ = precision_recall_curve(positive_labels,
                                                positive_scores)
  return auc(recall, precision)


def rms_score(y_true, y_pred):
def rms_score(y_true, y_pred):
  """Computes RMS error."""
  """Computes RMS error."""
  return np.sqrt(mean_squared_error(y_true, y_pred))
  return np.sqrt(mean_squared_error(y_true, y_pred))
@@ -148,7 +158,7 @@ class Metric(object):
      if self.metric.__name__ in [
      if self.metric.__name__ in [
          "roc_auc_score", "matthews_corrcoef", "recall_score",
          "roc_auc_score", "matthews_corrcoef", "recall_score",
          "accuracy_score", "kappa_score", "precision_score",
          "accuracy_score", "kappa_score", "precision_score",
          "balanced_accuracy_score"
          "balanced_accuracy_score", "prc_auc_score"
      ]:
      ]:
        mode = "classification"
        mode = "classification"
      elif self.metric.__name__ in [
      elif self.metric.__name__ in [
@@ -267,7 +277,7 @@ class Metric(object):
      # TODO(rbharath): This has been a major source of bugs. Is there a more
      # TODO(rbharath): This has been a major source of bugs. Is there a more
      # robust characterization of which metrics require class-probs and which
      # robust characterization of which metrics require class-probs and which
      # don't?
      # don't?
      if "roc_auc_score" in self.name:
      if "roc_auc_score" in self.name or "prc_auc_score" in self.name:
        y_true = to_one_hot(y_true).astype(int)
        y_true = to_one_hot(y_true).astype(int)
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
      else:
      else:
+61 −12
Original line number Original line Diff line number Diff line
@@ -72,6 +72,9 @@ class WeaveLayer(Layer):
      Number of features for each pair of atoms in output.
      Number of features for each pair of atoms in output.
    n_hidden_XX: int
    n_hidden_XX: int
      Number of units(convolution depths) in corresponding hidden layer
      Number of units(convolution depths) in corresponding hidden layer
    update_pair: bool, optional
      Whether to update the pair features;
      can be turned off for the last layer
    init: str, optional
    init: str, optional
      Weight initialization for filters.
      Weight initialization for filters.
    activation: str, optional
    activation: str, optional
@@ -198,8 +201,14 @@ class WeaveGather(Layer):
    ----------
    ----------
    batch_size: int
    batch_size: int
      number of molecules in a batch
      number of molecules in a batch
    n_input: int, optional
      number of features for each input molecule
    gaussian_expand: boolean, optional
    gaussian_expand: boolean, optional
      Whether to expand each dimension of atomic features by gaussian histogram
      Whether to expand each dimension of atomic features by gaussian histogram
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function applied


    """
    """
    self.n_input = n_input
    self.n_input = n_input
@@ -266,6 +275,16 @@ class DTNNEmbedding(Layer):
               periodic_table_length=83,
               periodic_table_length=83,
               init='glorot_uniform',
               init='glorot_uniform',
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_embedding: int, optional
      Number of features for each atom
    periodic_table_length: int, optional
      Length of embedding, 83=Bi
    init: str, optional
      Weight initialization for filters.
    """
    self.n_embedding = n_embedding
    self.n_embedding = n_embedding
    self.periodic_table_length = periodic_table_length
    self.periodic_table_length = periodic_table_length
    self.init = initializations.get(init)  # Set weight initialization
    self.init = initializations.get(init)  # Set weight initialization
@@ -300,6 +319,20 @@ class DTNNStep(Layer):
               init='glorot_uniform',
               init='glorot_uniform',
               activation='tanh',
               activation='tanh',
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_embedding: int, optional
      Number of features for each atom
    n_distance: int, optional
      granularity of distance matrix
    n_hidden: int, optional
      Number of nodes in hidden layer
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function applied
    """
    self.n_embedding = n_embedding
    self.n_embedding = n_embedding
    self.n_distance = n_distance
    self.n_distance = n_distance
    self.n_hidden = n_hidden
    self.n_hidden = n_hidden
@@ -365,6 +398,20 @@ class DTNNGather(Layer):
               init='glorot_uniform',
               init='glorot_uniform',
               activation='tanh',
               activation='tanh',
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_embedding: int, optional
      Number of features for each atom
    n_outputs: int, optional
      Number of features for each molecule(output)
    layer_sizes: list of int, optional(default=[1000])
      Structure of hidden layer(s)
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function applied
    """
    self.n_embedding = n_embedding
    self.n_embedding = n_embedding
    self.n_outputs = n_outputs
    self.n_outputs = n_outputs
    self.layer_sizes = layer_sizes
    self.layer_sizes = layer_sizes
@@ -413,20 +460,22 @@ class DAGLayer(Layer):
  def __init__(self,
  def __init__(self,
               n_graph_feat=30,
               n_graph_feat=30,
               n_atom_feat=75,
               n_atom_feat=75,
               max_atoms=50,
               layer_sizes=[100],
               layer_sizes=[100],
               init='glorot_uniform',
               init='glorot_uniform',
               activation='relu',
               activation='relu',
               dropout=None,
               dropout=None,
               max_atoms=50,
               batch_size=64,
               batch_size=64,
               **kwargs):
               **kwargs):
    """
    """
    Parameters
    Parameters
    ----------
    ----------
    n_graph_feat: int
    n_graph_feat: int, optional
      Number of features for each node(and the whole graph).
      Number of features for each node(and the whole graph).
    n_atom_feat: int
    n_atom_feat: int, optional
      Number of features listed per atom.
      Number of features listed per atom.
    max_atoms: int, optional
      Maximum number of atoms in molecules.
    layer_sizes: list of int, optional(default=[100])
    layer_sizes: list of int, optional(default=[100])
      Structure of hidden layer(s)
      Structure of hidden layer(s)
    init: str, optional
    init: str, optional
@@ -435,8 +484,8 @@ class DAGLayer(Layer):
      Activation function applied
      Activation function applied
    dropout: float, optional
    dropout: float, optional
      Dropout probability, not supported here
      Dropout probability, not supported here
    max_atoms: int, optional
    batch_size: int, optional
      Maximum number of atoms in molecules.
      number of molecules in a batch
    """
    """
    super(DAGLayer, self).__init__(**kwargs)
    super(DAGLayer, self).__init__(**kwargs)


@@ -552,20 +601,22 @@ class DAGGather(Layer):
  def __init__(self,
  def __init__(self,
               n_graph_feat=30,
               n_graph_feat=30,
               n_outputs=30,
               n_outputs=30,
               max_atoms=50,
               layer_sizes=[100],
               layer_sizes=[100],
               init='glorot_uniform',
               init='glorot_uniform',
               activation='relu',
               activation='relu',
               dropout=None,
               dropout=None,
               max_atoms=50,
               **kwargs):
               **kwargs):
    """
    """
    Parameters
    Parameters
    ----------
    ----------
    n_graph_feat: int
    n_graph_feat: int, optional
      Number of features for each atom
      Number of features for each atom
    n_outputs: int
    n_outputs: int, optional
      Number of features for each molecule.
      Number of features for each molecule.
    layer_sizes: list of int, optional(default=[1000])
    max_atoms: int, optional
      Maximum number of atoms in molecules.
    layer_sizes: list of int, optional
      Structure of hidden layer(s)
      Structure of hidden layer(s)
    init: str, optional
    init: str, optional
      Weight initialization for filters.
      Weight initialization for filters.
@@ -573,8 +624,6 @@ class DAGGather(Layer):
      Activation function applied
      Activation function applied
    dropout: float, optional
    dropout: float, optional
      Dropout probability, not supported
      Dropout probability, not supported
    max_atoms: int, optional
      Maximum number of atoms in molecules.
    """
    """
    super(DAGGather, self).__init__(**kwargs)
    super(DAGGather, self).__init__(**kwargs)


@@ -608,7 +657,7 @@ class DAGGather(Layer):


  def _create_tensor(self):
  def _create_tensor(self):
    """description and explanation refer to deepchem.nn.DAGGather
    """description and explanation refer to deepchem.nn.DAGGather
    parent layers: atom_features
    parent layers: atom_features, membership
    """
    """
    # Add trainable weights
    # Add trainable weights
    self.build()
    self.build()
+66 −0
Original line number Original line Diff line number Diff line
@@ -21,6 +21,21 @@ class WeaveTensorGraph(TensorGraph):
               n_hidden=50,
               n_hidden=50,
               n_graph_feat=128,
               n_graph_feat=128,
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_tasks: int
      Number of tasks
    n_atom_feat: int, optional
      Number of features per atom.
    n_pair_feat: int, optional
      Number of features per pair of atoms.
    n_hidden: int, optional
      Number of units(convolution depths) in corresponding hidden layer
    n_graph_feat: int, optional
      Number of output features for each molecule(graph)

    """
    self.n_tasks = n_tasks
    self.n_tasks = n_tasks
    self.n_atom_feat = n_atom_feat
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    self.n_pair_feat = n_pair_feat
@@ -30,6 +45,9 @@ class WeaveTensorGraph(TensorGraph):
    self.build_graph()
    self.build_graph()


  def build_graph(self):
  def build_graph(self):
    """Building graph structures:
    Features => WeaveLayer => WeaveLayer => Dense => WeaveGather => Classification or Regression
    """
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.pair_features = Feature(shape=(None, self.n_pair_feat))
    self.pair_features = Feature(shape=(None, self.n_pair_feat))
    combined = Combine_AP(in_layers=[self.atom_features, self.pair_features])
    combined = Combine_AP(in_layers=[self.atom_features, self.pair_features])
@@ -94,6 +112,9 @@ class WeaveTensorGraph(TensorGraph):
                        epochs=1,
                        epochs=1,
                        predict=False,
                        predict=False,
                        pad_batches=True):
                        pad_batches=True):
    """ TensorGraph style implementation
    similar to deepchem.models.tf_new_models.graph_topology.AlternateWeaveTopology.batch_to_feed_dict
    """
    for epoch in range(epochs):
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          batch_size=self.batch_size,
@@ -190,6 +211,24 @@ class DTNNTensorGraph(TensorGraph):
               distance_min=-1,
               distance_min=-1,
               distance_max=18,
               distance_max=18,
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_tasks: int
      Number of tasks
    n_embedding: int, optional
      Number of features per atom.
    n_hidden: int, optional
      Number of features for each molecule after DTNNStep
    n_distance: int, optional
      granularity of distance matrix
      step size will be (distance_max-distance_min)/n_distance
    distance_min: float, optional
      minimum distance of atom pairs, default = -1 Angstrom
    distance_max: float, optional
      maximum distance of atom pairs, default = 18 Angstrom

    """
    self.n_tasks = n_tasks
    self.n_tasks = n_tasks
    self.n_embedding = n_embedding
    self.n_embedding = n_embedding
    self.n_hidden = n_hidden
    self.n_hidden = n_hidden
@@ -205,6 +244,9 @@ class DTNNTensorGraph(TensorGraph):
    self.build_graph()
    self.build_graph()


  def build_graph(self):
  def build_graph(self):
    """Building graph structures:
    Features => DTNNEmbedding => DTNNStep => DTNNStep => DTNNGather => Regression
    """
    self.atom_number = Feature(shape=(None,), dtype=tf.int32)
    self.atom_number = Feature(shape=(None,), dtype=tf.int32)
    self.distance = Feature(shape=(None, self.n_distance))
    self.distance = Feature(shape=(None, self.n_distance))
    self.atom_membership = Feature(shape=(None,), dtype=tf.int32)
    self.atom_membership = Feature(shape=(None,), dtype=tf.int32)
@@ -254,6 +296,9 @@ class DTNNTensorGraph(TensorGraph):
                        epochs=1,
                        epochs=1,
                        predict=False,
                        predict=False,
                        pad_batches=True):
                        pad_batches=True):
    """ TensorGraph style implementation
    similar to deepchem.models.tf_new_models.graph_topology.DTNNGraphTopology.batch_to_feed_dict
    """
    for epoch in range(epochs):
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          batch_size=self.batch_size,
@@ -348,6 +393,21 @@ class DAGTensorGraph(TensorGraph):
               n_graph_feat=30,
               n_graph_feat=30,
               n_outputs=30,
               n_outputs=30,
               **kwargs):
               **kwargs):
    """
    Parameters
    ----------
    n_tasks: int
      Number of tasks
    max_atoms: int, optional
      Maximum number of atoms in a molecule, should be defined based on dataset
    n_atom_feat: int, optional
      Number of features per atom.
    n_graph_feat: int, optional
      Number of features for atom in the graph
    n_outputs: int, optional
      Number of features for each molecule

    """
    self.n_tasks = n_tasks
    self.n_tasks = n_tasks
    self.max_atoms = max_atoms
    self.max_atoms = max_atoms
    self.n_atom_feat = n_atom_feat
    self.n_atom_feat = n_atom_feat
@@ -357,6 +417,9 @@ class DAGTensorGraph(TensorGraph):
    self.build_graph()
    self.build_graph()


  def build_graph(self):
  def build_graph(self):
    """Building graph structures:
    Features => DAGLayer => DAGGather => Classification or Regression
    """
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.parents = Feature(
    self.parents = Feature(
        shape=(None, self.max_atoms, self.max_atoms), dtype=tf.int32)
        shape=(None, self.max_atoms, self.max_atoms), dtype=tf.int32)
@@ -414,6 +477,9 @@ class DAGTensorGraph(TensorGraph):
                        epochs=1,
                        epochs=1,
                        predict=False,
                        predict=False,
                        pad_batches=True):
                        pad_batches=True):
    """ TensorGraph style implementation
    similar to deepchem.models.tf_new_models.graph_topology.DAGGraphTopology.batch_to_feed_dict
    """
    for epoch in range(epochs):
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          batch_size=self.batch_size,
+174 −11
Original line number Original line Diff line number Diff line
@@ -705,6 +705,40 @@ class TestOverfit(test_util.TensorFlowTestCase):


    assert scores[regression_metric.name] > .9
    assert scores[regression_metric.name] > .9


  def test_tensorgraph_DTNN_multitask_regression_overfit(self):
    """Test that DTNNTensorGraph overfits a tiny multitask regression dataset.

    Trains on the example DTNN data for 20 epochs and requires the
    task-averaged Pearson R^2 on that same data to exceed 0.9.
    """
    # Fix both RNGs so the run is repeatable.
    np.random.seed(123)
    tf.set_random_seed(123)

    # Load the example DTNN data from a MATLAB .mat file:
    # 'X' holds the inputs and 'T' the regression targets.
    input_file = os.path.join(self.current_dir, "example_DTNN.mat")
    dataset = scipy.io.loadmat(input_file)
    X = dataset['X']
    y = dataset['T']
    # Weight every target entry equally.
    w = np.ones_like(y)
    dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids=None)
    regression_metric = dc.metrics.Metric(
        dc.metrics.pearson_r2_score, task_averager=np.mean)
    # One task per column of the target matrix.
    n_tasks = y.shape[1]
    batch_size = 10

    model = dc.models.DTNNTensorGraph(
        n_tasks,
        n_embedding=20,
        n_distance=100,
        batch_size=batch_size,
        learning_rate=0.001,
        use_queue=False,
        mode="regression")

    # Train the model
    model.fit(dataset, nb_epoch=20)

    # Evaluate on the training data itself (overfit check)
    scores = model.evaluate(dataset, [regression_metric])

    assert scores[regression_metric.name] > .9

  def test_DAG_singletask_regression_overfit(self):
  def test_DAG_singletask_regression_overfit(self):
    """Test DAG regressor multitask overfits tiny data."""
    """Test DAG regressor multitask overfits tiny data."""
    np.random.seed(123)
    np.random.seed(123)
@@ -729,7 +763,7 @@ class TestOverfit(test_util.TensorFlowTestCase):


    graph = dc.nn.SequentialDAGGraph(n_atom_feat=n_feat, max_atoms=50)
    graph = dc.nn.SequentialDAGGraph(n_atom_feat=n_feat, max_atoms=50)
    graph.add(dc.nn.DAGLayer(30, n_feat, max_atoms=50, batch_size=batch_size))
    graph.add(dc.nn.DAGLayer(30, n_feat, max_atoms=50, batch_size=batch_size))
    graph.add(dc.nn.DAGGather(max_atoms=50))
    graph.add(dc.nn.DAGGather(30, max_atoms=50))


    model = dc.models.MultitaskGraphRegressor(
    model = dc.models.MultitaskGraphRegressor(
        graph,
        graph,
@@ -750,6 +784,44 @@ class TestOverfit(test_util.TensorFlowTestCase):


    assert scores[regression_metric.name] > .8
    assert scores[regression_metric.name] > .8


  def test_tensorgraph_DAG_singletask_regression_overfit(self):
    """Test that DAGTensorGraph overfits a tiny single-task regression dataset.

    Trains on the example regression CSV for 50 epochs and requires the
    Pearson R^2 on that same data to exceed 0.8.
    """
    # Fix both RNGs so the run is repeatable.
    np.random.seed(123)
    tf.set_random_seed(123)
    n_tasks = 1

    # Featurize the example regression CSV
    # (SMILES column "smiles", single task "outcome").
    featurizer = dc.feat.ConvMolFeaturizer()
    tasks = ["outcome"]
    input_file = os.path.join(self.current_dir, "example_regression.csv")
    loader = dc.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
    dataset = loader.featurize(input_file)

    regression_metric = dc.metrics.Metric(
        dc.metrics.pearson_r2_score, task_averager=np.mean)

    n_feat = 75
    batch_size = 10
    # Apply the DAG transform; max_atoms here matches the model's max_atoms below.
    transformer = dc.trans.DAGTransformer(max_atoms=50)
    dataset = transformer.transform(dataset)

    model = dc.models.DAGTensorGraph(
        n_tasks,
        max_atoms=50,
        n_atom_feat=n_feat,
        batch_size=batch_size,
        learning_rate=0.001,
        use_queue=False,
        mode="regression")

    # Train the model
    model.fit(dataset, nb_epoch=50)
    # Evaluate on the training data itself (overfit check)
    scores = model.evaluate(dataset, [regression_metric])

    assert scores[regression_metric.name] > .8

  def test_weave_singletask_classification_overfit(self):
  def test_weave_singletask_classification_overfit(self):
    """Test weave model overfits tiny data."""
    """Test weave model overfits tiny data."""
    np.random.seed(123)
    np.random.seed(123)
@@ -772,12 +844,18 @@ class TestOverfit(test_util.TensorFlowTestCase):
    batch_size = 10
    batch_size = 10
    max_atoms = 50
    max_atoms = 50


    graph = dc.nn.SequentialWeaveGraph(
    graph = dc.nn.AlternateSequentialWeaveGraph(
        max_atoms=max_atoms, n_atom_feat=n_atom_feat, n_pair_feat=n_pair_feat)
        batch_size,
    graph.add(dc.nn.WeaveLayer(max_atoms, 75, 14))
        max_atoms=max_atoms,
    graph.add(dc.nn.WeaveConcat(batch_size, n_output=n_feat))
        n_atom_feat=n_atom_feat,
        n_pair_feat=n_pair_feat)
    graph.add(dc.nn.AlternateWeaveLayer(max_atoms, 75, 14))
    graph.add(dc.nn.AlternateWeaveLayer(max_atoms, 50, 50, update_pair=False))
    graph.add(dc.nn.Dense(n_feat, 50, activation='tanh'))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.WeaveGather(batch_size, n_input=n_feat))
    graph.add(
        dc.nn.AlternateWeaveGather(
            batch_size, n_input=n_feat, gaussian_expand=True))


    model = dc.models.MultitaskGraphClassifier(
    model = dc.models.MultitaskGraphClassifier(
        graph,
        graph,
@@ -799,6 +877,45 @@ class TestOverfit(test_util.TensorFlowTestCase):


    assert scores[classification_metric.name] > .65
    assert scores[classification_metric.name] > .65


  def test_tensorgraph_weave_singletask_classification_overfit(self):
    """Test that WeaveTensorGraph overfits a tiny single-task classification dataset.

    Trains on the example classification CSV for 20 epochs and requires
    the accuracy on that same data to exceed 0.65.
    """
    # Fix both RNGs so the run is repeatable.
    np.random.seed(123)
    tf.set_random_seed(123)
    n_tasks = 1

    # Featurize the example classification CSV
    # (SMILES column "smiles", single task "outcome").
    featurizer = dc.feat.WeaveFeaturizer()
    tasks = ["outcome"]
    input_file = os.path.join(self.current_dir, "example_classification.csv")
    loader = dc.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
    dataset = loader.featurize(input_file)

    classification_metric = dc.metrics.Metric(dc.metrics.accuracy_score)

    # Feature sizes handed to the model constructor.
    n_atom_feat = 75
    n_pair_feat = 14
    n_feat = 128
    batch_size = 10

    model = dc.models.WeaveTensorGraph(
        n_tasks,
        n_atom_feat=n_atom_feat,
        n_pair_feat=n_pair_feat,
        n_graph_feat=n_feat,
        batch_size=batch_size,
        learning_rate=0.001,
        use_queue=False,
        mode="classification")

    # Train the model
    model.fit(dataset, nb_epoch=20)

    # Evaluate on the training data itself (overfit check)
    scores = model.evaluate(dataset, [classification_metric])

    assert scores[classification_metric.name] > .65

  def test_weave_singletask_regression_overfit(self):
  def test_weave_singletask_regression_overfit(self):
    """Test weave model overfits tiny data."""
    """Test weave model overfits tiny data."""
    np.random.seed(123)
    np.random.seed(123)
@@ -822,12 +939,18 @@ class TestOverfit(test_util.TensorFlowTestCase):
    batch_size = 10
    batch_size = 10
    max_atoms = 50
    max_atoms = 50


    graph = dc.nn.SequentialWeaveGraph(
    graph = dc.nn.AlternateSequentialWeaveGraph(
        max_atoms=max_atoms, n_atom_feat=n_atom_feat, n_pair_feat=n_pair_feat)
        batch_size,
    graph.add(dc.nn.WeaveLayer(max_atoms, 75, 14))
        max_atoms=max_atoms,
    graph.add(dc.nn.WeaveConcat(batch_size, n_output=n_feat))
        n_atom_feat=n_atom_feat,
        n_pair_feat=n_pair_feat)
    graph.add(dc.nn.AlternateWeaveLayer(max_atoms, 75, 14))
    graph.add(dc.nn.AlternateWeaveLayer(max_atoms, 50, 50, update_pair=False))
    graph.add(dc.nn.Dense(n_feat, 50, activation='tanh'))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.WeaveGather(batch_size, n_input=n_feat))
    graph.add(
        dc.nn.AlternateWeaveGather(
            batch_size, n_input=n_feat, gaussian_expand=True))


    model = dc.models.MultitaskGraphRegressor(
    model = dc.models.MultitaskGraphRegressor(
        graph,
        graph,
@@ -849,6 +972,46 @@ class TestOverfit(test_util.TensorFlowTestCase):


    assert scores[regression_metric.name] > .9
    assert scores[regression_metric.name] > .9


  def test_tensorgraph_weave_singletask_regression_overfit(self):
    """Test that WeaveTensorGraph overfits a tiny single-task regression dataset.

    Trains on the example regression CSV for 120 epochs and requires the
    Pearson R^2 on that same data to exceed 0.9.
    """
    # Fix both RNGs so the run is repeatable.
    np.random.seed(123)
    tf.set_random_seed(123)
    n_tasks = 1

    # Featurize the example regression CSV
    # (SMILES column "smiles", single task "outcome").
    featurizer = dc.feat.WeaveFeaturizer()
    tasks = ["outcome"]
    input_file = os.path.join(self.current_dir, "example_regression.csv")
    loader = dc.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
    dataset = loader.featurize(input_file)

    regression_metric = dc.metrics.Metric(
        dc.metrics.pearson_r2_score, task_averager=np.mean)

    # Feature sizes handed to the model constructor.
    n_atom_feat = 75
    n_pair_feat = 14
    n_feat = 128
    batch_size = 10

    model = dc.models.WeaveTensorGraph(
        n_tasks,
        n_atom_feat=n_atom_feat,
        n_pair_feat=n_pair_feat,
        n_graph_feat=n_feat,
        batch_size=batch_size,
        learning_rate=0.001,
        use_queue=False,
        mode="regression")

    # Train the model
    model.fit(dataset, nb_epoch=120)

    # Evaluate on the training data itself (overfit check)
    scores = model.evaluate(dataset, [regression_metric])

    assert scores[regression_metric.name] > .9

  def test_siamese_singletask_classification_overfit(self):
  def test_siamese_singletask_classification_overfit(self):
    """Test siamese singletask model overfits tiny data."""
    """Test siamese singletask model overfits tiny data."""
    np.random.seed(123)
    np.random.seed(123)
+24 −4
Original line number Original line Diff line number Diff line
@@ -90,8 +90,6 @@ class SequentialDTNNGraph(SequentialGraph):
    """
    """
    Parameters
    Parameters
    ----------
    ----------
    max_n_atoms: int
      maximum number of atoms in a molecule
    n_distance: int, optional
    n_distance: int, optional
      granularity of distance matrix
      granularity of distance matrix
      step size will be (distance_max-distance_min)/n_distance
      step size will be (distance_max-distance_min)/n_distance
@@ -130,9 +128,9 @@ class SequentialDAGGraph(SequentialGraph):
    """
    """
    Parameters
    Parameters
    ----------
    ----------
    n_atom_feat: int
    n_atom_feat: int, optional
      Number of features per atom.
      Number of features per atom.
    max_atoms: int, optional(default=50)
    max_atoms: int, optional
      Maximum number of atoms in a molecule, should be defined based on dataset
      Maximum number of atoms in a molecule, should be defined based on dataset
    """
    """
    self.graph = tf.Graph()
    self.graph = tf.Graph()
@@ -161,6 +159,16 @@ class SequentialWeaveGraph(SequentialGraph):
  """
  """


  def __init__(self, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
  def __init__(self, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
    """
    Parameters
    ----------
    max_atoms: int, optional
      Maximum number of atoms in a molecule, should be defined based on dataset
    n_atom_feat: int, optional
      Number of features per atom.
    n_pair_feat: int, optional
      Number of features per pair of atoms.
    """
    self.graph = tf.Graph()
    self.graph = tf.Graph()
    self.max_atoms = max_atoms
    self.max_atoms = max_atoms
    self.n_atom_feat = n_atom_feat
    self.n_atom_feat = n_atom_feat
@@ -195,6 +203,18 @@ class AlternateSequentialWeaveGraph(SequentialGraph):
  """
  """


  def __init__(self, batch_size, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
  def __init__(self, batch_size, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
    """
    Parameters
    ----------
    batch_size: int
      number of molecules in a batch
    max_atoms: int, optional
      Maximum number of atoms in a molecule, should be defined based on dataset
    n_atom_feat: int, optional
      Number of features per atom.
    n_pair_feat: int, optional
      Number of features per pair of atoms.
    """
    self.graph = tf.Graph()
    self.graph = tf.Graph()
    self.batch_size = batch_size
    self.batch_size = batch_size
    self.max_atoms = max_atoms
    self.max_atoms = max_atoms
Loading