Commit 01931f0c authored by nitinprakash96's avatar nitinprakash96
Browse files

Merge branch 'master' of https://github.com/deepchem/deepchem into acnn

parents f1bda88e 30b03541
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -17,15 +17,17 @@ from deepchem.models.tensorgraph.IRV import TensorflowMultiTaskIRVClassifier
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskClassifier
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskRegressor
from deepchem.models.tensorgraph.progressive_multitask import ProgressiveMultitaskRegressor, ProgressiveMultitaskClassifier
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNTensorGraph, DAGTensorGraph, GraphConvModel, MPNNTensorGraph
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNModel, DAGModel, GraphConvModel, MPNNModel
from deepchem.models.tensorgraph.models.symmetry_function_regression import BPSymmetryFunctionRegression, ANIRegression
from deepchem.models.tensorgraph.models.scscore import ScScoreModel

from deepchem.models.tensorgraph.models.seqtoseq import SeqToSeq
from deepchem.models.tensorgraph.models.gan import GAN, WGAN
from deepchem.models.tensorgraph.models.text_cnn import TextCNNTensorGraph
from deepchem.models.tensorgraph.models.text_cnn import TextCNNModel
from deepchem.models.tensorgraph.sequential import Sequential
from deepchem.models.tensorgraph.models.sequence_dnn import SequenceDNN

#################### Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0. ####################

from deepchem.models.tensorgraph.models.graph_models import WeaveTensorGraph, GraphConvTensorGraph
 No newline at end of file
from deepchem.models.tensorgraph.models.text_cnn import TextCNNTensorGraph
from deepchem.models.tensorgraph.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph, GraphConvTensorGraph, MPNNTensorGraph
+1 −1
Original line number Diff line number Diff line
@@ -925,7 +925,7 @@ class EdgeNetwork(object):

  def forward(self, atom_features, atom_to_pair):
    """Compute the summed edge messages for each source atom.

    Parameters
    ----------
    atom_features: Tensor of per-atom feature vectors.
    atom_to_pair: Integer tensor of (source, target) atom index pairs;
      column 0 is the source atom, column 1 the target atom.

    Returns
    -------
    Tensor with one aggregated message row per source atom.
    """
    # Gather target-atom features and append a trailing unit axis so the
    # matmul below contracts over the feature dimension.
    out = tf.expand_dims(tf.gather(atom_features, atom_to_pair[:, 1]), 2)
    # NOTE(review): the pre-merge line `out = tf.reduce_sum(out * self.A,
    # axis=1)` was replaced by this matmul in the merge; keeping both would
    # clobber `out`, so only the post-merge form is retained.
    out = tf.squeeze(tf.matmul(self.A, out), axis=2)
    # Sum the per-pair messages over all pairs sharing the same source atom.
    out = tf.segment_sum(out, atom_to_pair[:, 0])
    return out

+5 −6
Original line number Diff line number Diff line
@@ -4,8 +4,7 @@ from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
from deepchem.models.tensorgraph.model_ops import random_uniform_variable
from deepchem.models.tensorgraph.model_ops import random_normal_variable
from deepchem.models.tensorgraph.model_ops import random_uniform_variable, random_normal_variable, create_variable
from deepchem.models.tensorgraph.activations import get_from_module


@@ -94,7 +93,7 @@ def orthogonal(shape, scale=1.1, name=None):
  # Pick the one with the correct shape.
  q = u if u.shape == flat_shape else v
  q = q.reshape(shape)
  return tf.Variable(
  return create_variable(
      scale * q[:shape[0], :shape[1]], dtype=tf.float32, name=name)


@@ -103,16 +102,16 @@ def identity(shape, scale=1, name=None):
    raise ValueError('Identity matrix initialization can only be used '
                     'for 2D square matrices.')
  else:
    return tf.Variable(
    return create_variable(
        scale * np.identity(shape[0]), dtype=tf.float32, name=name)


def zero(shape, name=None):
  """Instantiate an all-zeros variable of the given shape.

  Parameters
  ----------
  shape: Tuple of integers, shape of the returned Tensorflow variable.
  name: String, optional name of the returned variable.

  Returns
  -------
  A Tensorflow variable filled with `0.0`.
  """
  # The unmarked diff left both the old `tf.Variable` return and the merged
  # `create_variable` return; only the post-merge version is kept so the
  # eager/graph mode dispatch in create_variable applies.
  return create_variable(tf.zeros(shape), dtype=tf.float32, name=name)


def one(shape, name=None):
  """Instantiate an all-ones variable of the given shape.

  Parameters
  ----------
  shape: Tuple of integers, shape of the returned Tensorflow variable.
  name: String, optional name of the returned variable.

  Returns
  -------
  A Tensorflow variable filled with `1.0`.
  """
  # The unmarked diff left both the old `tf.Variable` return and the merged
  # `create_variable` return; only the post-merge version is kept so the
  # eager/graph mode dispatch in create_variable applies.
  return create_variable(tf.ones(shape), dtype=tf.float32, name=name)


def get(identifier, **kwargs):
+504 −191

File changed.

Preview size limit exceeded, changes collapsed.

+31 −22
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import sys
import traceback
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.python.training import moving_averages
from collections import defaultdict

@@ -22,28 +23,36 @@ def _to_tensor(x, dtype):
  return x


def create_variable(value, dtype=None, name=None):
  """Create a tf.Variable or tfe.Variable, depending on the current mode."""
  # Pick the variable class once, then construct: eager execution requires
  # tfe.Variable, while graph mode uses the classic tf.Variable.
  variable_cls = tfe.Variable if tfe.in_eager_mode() else tf.Variable
  return variable_cls(value, dtype=dtype, name=name)


def ones(shape, dtype=None, name=None):
  """Instantiates an all-ones tensor variable and returns it.

  Parameters
  ----------
  shape: Tuple of integers, shape of returned Tensorflow variable.
  dtype: Tensorflow dtype
  name: String, name of returned Tensorflow variable.

  Returns
  -------
  A Tensorflow variable, filled with `1.0`.
  """
  # The unmarked diff interleaved the old "Keras" docstring lines and the
  # old `tf.Variable` return with their merged replacements; only the
  # post-merge version is kept here.
  if dtype is None:
    dtype = tf.float32
  # Coerce to plain ints so the initializer accepts e.g. numpy int shapes.
  shape = tuple(map(int, shape))
  return create_variable(
      tf.constant_initializer(1., dtype=dtype)(shape), dtype, name)


def cast_to_floatx(x):
  """Cast a Numpy array to the default Keras float type.
  """Cast a Numpy array to the default Tensorflow float type.

  Parameters
  ----------
@@ -65,7 +74,7 @@ def moving_average_update(variable, value, momentum):


def int_shape(x):
  """Returns the shape of a Keras tensor or a Keras variable as a tuple of
  """Returns the shape of a Tensorflow tensor or a Tensorflow variable as a tuple of
  integers or None entries.

  Arguments
@@ -193,7 +202,7 @@ def get_ndim(x):


def get_dtype(x):
  """Returns the dtype of a Keras tensor or variable, as a string.
  """Returns the dtype of a Tensorflow tensor or variable, as a string.

  Parameters
  ----------
@@ -259,7 +268,7 @@ def random_uniform_variable(shape,
    seed = np.random.randint(10e8)
  value = tf.random_uniform_initializer(
      low, high, dtype=dtype, seed=seed)(shape)
  return tf.Variable(value, dtype=dtype, name=name)
  return create_variable(value, dtype=dtype, name=name)


def random_normal_variable(shape,
@@ -268,16 +277,16 @@ def random_normal_variable(shape,
                           dtype=tf.float32,
                           name=None,
                           seed=None):
  """Instantiates an Keras variable filled with
  """Instantiates an Tensorflow variable filled with
  samples drawn from a normal distribution and returns it.

  Parameters
  ----------
  shape: Tuple of integers, shape of returned Keras variable.
  shape: Tuple of integers, shape of returned Tensorflow variable.
  mean: Float, mean of the normal distribution.
  scale: Float, standard deviation of the normal distribution.
  dtype: Tensorflow dtype
  name: String, name of returned Keras variable.
  name: String, name of returned Tensorflow variable.
  seed: Integer, random seed.

  Returns
@@ -290,7 +299,7 @@ def random_normal_variable(shape,
    seed = np.random.randint(10e8)
  value = tf.random_normal_initializer(
      mean, scale, dtype=dtype, seed=seed)(shape)
  return tf.Variable(value, dtype=dtype, name=name)
  return create_variable(value, dtype=dtype, name=name)


def max(x, axis=None, keepdims=False):
@@ -338,7 +347,7 @@ def categorical_crossentropy(output, target, from_logits=False):
  # TODO(rbharath): Should probably swap this over to tf mode.
  """
  # Note: tf.nn.softmax_cross_entropy_with_logits
  # expects logits, Keras expects probabilities.
  # expects logits, Tensorflow expects probabilities.
  if not from_logits:
    # scale preds so that the class probas of each sample sum to 1
    output /= tf.reduce_sum(
@@ -362,7 +371,7 @@ def sparse_categorical_crossentropy(output, target, from_logits=False):
  and a target tensor, where the target is an integer tensor.
  """
  # Note: tf.nn.softmax_cross_entropy_with_logits
  # expects logits, Keras expects probabilities.
  # expects logits, Tensorflow expects probabilities.
  if not from_logits:
    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
    output = tf.clip_by_value(output, epsilon, 1 - epsilon)
@@ -398,7 +407,7 @@ def binary_crossentropy(output, target, from_logits=False):
      A tensor.
  """
  # Note: tf.nn.softmax_cross_entropy_with_logits
  # expects logits, Keras expects probabilities.
  # expects logits, Tensorflow expects probabilities.
  if not from_logits:
    # transform back to logits
    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
@@ -437,16 +446,16 @@ def zeros(shape, dtype=tf.float32, name=None):

  Parameters
  ----------
  shape: Tuple of integers, shape of returned Keras variable
  shape: Tuple of integers, shape of returned Tensorflow variable
  dtype: Tensorflow dtype
  name: String, name of returned Keras variable
  name: String, name of returned Tensorflow variable

  Returns
  -------
  A variable (including Keras metadata), filled with `0.0`.
  A variable (including Tensorflow metadata), filled with `0.0`.
  """
  shape = tuple(map(int, shape))
  return tf.Variable(
  return create_variable(
      tf.constant_initializer(0., dtype=dtype)(shape), dtype, name)


@@ -674,7 +683,7 @@ def add_bias(tensor, init=None, name=None):
  if init is None:
    init = tf.zeros([tensor.get_shape()[-1].value])
  with tf.name_scope(name, tensor.op.name, [tensor]):
    b = tf.Variable(init, name='b')
    b = create_variable(init, name='b')
    return tf.nn.bias_add(tensor, b)


@@ -752,8 +761,8 @@ def fully_connected_layer(tensor,
    bias_init = tf.zeros([size])

  with tf.name_scope(name, 'fully_connected', [tensor]):
    w = tf.Variable(weight_init, name='w', dtype=tf.float32)
    b = tf.Variable(bias_init, name='b', dtype=tf.float32)
    w = create_variable(weight_init, name='w', dtype=tf.float32)
    b = create_variable(bias_init, name='b', dtype=tf.float32)
    return tf.nn.xw_plus_b(tensor, w, b)


Loading