Commit 2ae7b775 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

tfp optional

'
parent b4d8b3e3
Loading
Loading
Loading
Loading
+16 −1
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import collections
from tensorflow.keras import activations, initializers, backend
@@ -2184,6 +2183,17 @@ class WeaveLayer(tf.keras.layers.Layer):


class WeaveGather(tf.keras.layers.Layer):
  """Implements the weave-gathering section of weave convolutions.

  Implements the gathering layer from the following paper:

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.

  The weave gathering layer gathers	per-atom features to create a
  molecule-level fingerprint in a weave convolutional network. This layer can
  also perform Gaussian histogram expansion as detailed in the original paper.
  """

  def __init__(self,
               batch_size,
@@ -2208,6 +2218,11 @@ class WeaveGather(tf.keras.layers.Layer):
    activation: str, optional
      Activation function applied
    """
    try:
      import tensorflow_probability as tfp
    except ModuleNotFoundError:
      raise ValueError(
          "This class requires tensorflow-probability to be installed.")
    super(WeaveGather, self).__init__(**kwargs)
    self.n_input = n_input
    self.batch_size = batch_size
+11 −2
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@ from deepchem.models import KerasModel
from deepchem.models.optimizers import Adam
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import collections
import copy
import multiprocessing
@@ -40,10 +39,20 @@ class A2CLossDiscrete(object):


class A2CLossContinuous(object):
  """This class computes the loss function for A2C with continuous action spaces."""
  """This class computes the loss function for A2C with continuous action spaces.

  Note
  ----
  This class requires tensorflow-probability to be installed.
  """

  def __init__(self, value_weight, entropy_weight, mean_index, std_index,
               value_index):
    try:
      import tensorflow_probability as tfp
    except ModuleNotFoundError:
      raise ValueError(
          "This class requires tensorflow-probability to be installed.")
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.mean_index = mean_index