Commit e392e133 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

remove cruft

parent 3a5cfde3
Loading
Loading
Loading
Loading
+2 −122
Original line number Diff line number Diff line
@@ -14,15 +14,6 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages
from collections import defaultdict
# TODO(rbharath): What does this line do?
py_all = all

# TODO(rbharath): REMOVE GLOBAL VARS! BREAKS DEEPCHEM STYLE!
_UID_PREFIXES = defaultdict(int)
# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = {}


def _to_tensor(x, dtype):
@@ -32,102 +23,6 @@ def _to_tensor(x, dtype):
  return x


def learning_phase():
  """Returns the learning phase flag.

  The learning phase flag is a bool tensor (0 = test, 1 = train)
  to be passed as input to any Keras function
  that uses a different behavior at train time and test time.
  """
  graph = tf.get_default_graph()
  if graph not in _GRAPH_LEARNING_PHASES:
    phase = tf.placeholder(dtype='bool', name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES[graph] = phase
  return _GRAPH_LEARNING_PHASES[graph]


def in_train_phase(x, alt):
  """Selects `x` in train phase, and `alt` otherwise.
  Note that `alt` should have the *same shape* as `x`.

  Returns
  -------
  Either `x` or `alt` based on `K.learning_phase`.
  """
  if learning_phase() is 1:
    return x
  elif learning_phase() is 0:
    return alt
  # else: assume learning phase is a placeholder tensor.
  x = switch(learning_phase(), x, alt)
  x._uses_learning_phase = True
  return x


def switch(condition, then_expression, else_expression):
  """Switches between two operations
  depending on a scalar value (`int` or `bool`).
  Note that both `then_expression` and `else_expression`
  should be symbolic tensors of the *same shape*.

  Parameters
  ----------
  condition: scalar tensor.
  then_expression: either a tensor, or a callable that returns a tensor.
  else_expression: either a tensor, or a callable that returns a tensor.

  Returns
  -------
  The selected tensor.
  """
  if condition.dtype != tf.bool:
    condition = tf.cast(condition, 'bool')
  if not callable(then_expression):

    def then_expression_fn():
      return then_expression
  else:
    then_expression_fn = then_expression
  if not callable(else_expression):

    def else_expression_fn():
      return else_expression
  else:
    else_expression_fn = else_expression
  x = tf.cond(condition, then_expression_fn, else_expression_fn)
  return x


def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
  """Computes mean and std for batch then apply batch_normalization on batch.

  Returns
  -------
  A tuple length of 3, (normalized_tensor, mean, variance).
  """
  mean, var = tf.nn.moments(
      x, reduction_axes, shift=None, name=None, keep_dims=False)
  if sorted(reduction_axes) == range(ndim(x))[:-1]:
    normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
  else:
    # need broadcasting
    target_shape = []
    for axis in range(get_ndim(x)):
      if axis in reduction_axes:
        target_shape.append(1)
      else:
        target_shape.append(tf.shape(x)[axis])
    target_shape = stack(target_shape)

    broadcast_mean = tf.reshape(mean, target_shape)
    broadcast_var = tf.reshape(var, target_shape)
    broadcast_gamma = tf.reshape(gamma, target_shape)
    broadcast_beta = tf.reshape(beta, target_shape)
    normed = tf.nn.batch_normalization(x, broadcast_mean, broadcast_var,
                                       broadcast_beta, broadcast_gamma, epsilon)
  return normed, mean, var


def ones(shape, dtype=None, name=None):
  """Instantiates an all-ones tensor variable and returns it.

@@ -186,21 +81,6 @@ def int_shape(x):
  return tuple([i.__int__() for i in shape])


def get_uid(prefix=''):
  """Provides a unique UID given a string prefix.

  Parameters
  ----------
  prefix: string.

  Returns
  -------
  An integer.
  """
  _UID_PREFIXES[prefix] += 1
  return _UID_PREFIXES[prefix]


def concatenate(tensors, axis=-1):
  """Concatenates a list of tensors alongside the specified axis.

@@ -586,8 +466,8 @@ def cosine_distances(test, support):
  tf.Tensor:
    Of shape (n_test, n_support)
  """
  rnorm_test = tf.rsqrt(
      tf.reduce_sum(tf.square(test), 1, keep_dims=True)) + 1e-7
  rnorm_test = tf.rsqrt(tf.reduce_sum(tf.square(test), 1,
                                      keep_dims=True)) + 1e-7
  rnorm_support = tf.rsqrt(
      tf.reduce_sum(tf.square(support), 1, keep_dims=True)) + 1e-7
  test_normalized = test * rnorm_test