Commit 4efb180c authored by Bharat123rox's avatar Bharat123rox
Browse files

Refactor code in contrib/tensorflow_models/

parent 81309ea8
Loading
Loading
Loading
Loading
+10 −5
Original line number Diff line number Diff line
@@ -6,6 +6,11 @@ import warnings
import numpy as np
import tensorflow as tf

# `Sequence` moved to `collections.abc` in Python 3.3 and the legacy
# `collections.Sequence` alias was removed in Python 3.10; prefer the new
# location and fall back only when it is genuinely missing (Python 2.x).
try:
  from collections.abc import Sequence as SequenceCollection
except ImportError:  # narrow: a bare `except:` would also swallow SystemExit/KeyboardInterrupt
  from collections import Sequence as SequenceCollection

from deepchem.nn import model_ops

class RobustMultitaskClassifier(MultiTaskClassifier):
@@ -73,15 +78,15 @@ class RobustMultitaskClassifier(MultiTaskClassifier):

    n_layers = len(layer_sizes)
    assert n_layers == len(bypass_layer_sizes)
    if not isinstance(weight_init_stddevs, collections.Sequence):
    if not isinstance(weight_init_stddevs, SequenceCollection):
      weight_init_stddevs = [weight_init_stddevs] * n_layers
    if not isinstance(bypass_weight_init_stddevs, collections.Sequence):
    if not isinstance(bypass_weight_init_stddevs, SequenceCollection):
      bypass_weight_init_stddevs = [bypass_weight_init_stddevs] * n_layers
    if not isinstance(bias_init_consts, collections.Sequence):
    if not isinstance(bias_init_consts, SequenceCollection):
      bias_init_consts = [bias_init_consts] * n_layers
    if not isinstance(dropouts, collections.Sequence):
    if not isinstance(dropouts, SequenceCollection):
      dropouts = [dropouts] * n_layers
    if not isinstance(activation_fns, collections.Sequence):
    if not isinstance(activation_fns, SequenceCollection):
      activation_fns = [activation_fns] * n_layers

    # Add the input features.