Commit 6f065998 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

bypass tensorflow imports

parent 340a9658
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
# flake8: noqa

try:
  from deepchem.metalearning.maml import MAML, MetaLearner
except ModuleNotFoundError:
  pass
+31 −26
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
Gathers all models in one place for convenient imports
"""
# flake8: noqa
try:
  from deepchem.models.models import Model
  from deepchem.models.keras_model import KerasModel
  from deepchem.models.multitask import SingletaskToMultitask
@@ -24,6 +25,8 @@ from deepchem.models.cnn import CNN
  from deepchem.models.text_cnn import TextCNNModel
  from deepchem.models.atomic_conv import AtomicConvModel
  from deepchem.models.chemnet_models import Smiles2Vec, ChemCeption
except ModuleNotFoundError:
  pass

# scikit-learn model
from deepchem.models.sklearn_models import SklearnModel
@@ -50,7 +53,9 @@ from deepchem.models.gbdt_models.gbdt_model import XGBoostModel
########################################################################################
# Compatibility imports for renamed TensorGraph models. Remove below with DeepChem 3.0.
########################################################################################

try:
  from deepchem.models.text_cnn import TextCNNTensorGraph
  from deepchem.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph, GraphConvTensorGraph, MPNNTensorGraph
  from deepchem.models.IRV import TensorflowMultitaskIRVClassifier
except ModuleNotFoundError:
  pass
+5 −3
Original line number Diff line number Diff line
"""Interface for reinforcement learning."""

try:
  from deepchem.rl.a2c import A2C  # noqa: F401
  from deepchem.rl.ppo import PPO  # noqa: F401
except ModuleNotFoundError:
  pass


class Environment(object):
+7 −2
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import scipy
import scipy.ndimage
import tensorflow as tf

import deepchem as dc
from deepchem.data import Dataset, NumpyDataset, DiskDataset
@@ -1611,6 +1610,7 @@ class IRVTransformer(Transformer):
      n_samples * np.array of size (2*K,)
      each array includes K similarity values and corresponding labels
    """
    import tensorflow as tf
    features = []
    similarity_xs = similarity * np.sign(w)
    [target_len, reference_len] = similarity_xs.shape
@@ -1994,6 +1994,7 @@ class ANITransformer(Transformer):
    """
    Only X can be transformed
    """
    import tensorflow as tf
    self.max_atoms = max_atoms
    self.radial_cutoff = radial_cutoff
    self.angular_cutoff = angular_cutoff
@@ -2038,6 +2039,7 @@ class ANITransformer(Transformer):

  def build(self):
    """ tensorflow computation graph for transform """
    import tensorflow as tf
    graph = tf.Graph()
    with graph.as_default():
      self.inputs = tf.keras.Input(
@@ -2065,6 +2067,7 @@ class ANITransformer(Transformer):

  def distance_matrix(self, coordinates, flags):
    """ Generate distance matrix """
    import tensorflow as tf
    max_atoms = self.max_atoms
    tensor1 = tf.stack([coordinates] * max_atoms, axis=1)
    tensor2 = tf.stack([coordinates] * max_atoms, axis=2)
@@ -2077,6 +2080,7 @@ class ANITransformer(Transformer):

  def distance_cutoff(self, d, cutoff, flags):
    """ Generate distance matrix with trainable cutoff """
    import tensorflow as tf
    # Cutoff with threshold Rc
    d_flag = flags * tf.sign(cutoff - d)
    d_flag = tf.nn.relu(d_flag)
@@ -2087,6 +2091,7 @@ class ANITransformer(Transformer):

  def radial_symmetry(self, d_cutoff, d, atom_numbers):
    """ Radial Symmetry Function """
    import tensorflow as tf
    embedding = tf.eye(np.max(self.atom_cases) + 1)
    atom_numbers_embedded = tf.nn.embedding_lookup(embedding, atom_numbers)

@@ -2113,7 +2118,7 @@ class ANITransformer(Transformer):

  def angular_symmetry(self, d_cutoff, d, atom_numbers, coordinates):
    """ Angular Symmetry Function """

    import tensorflow as tf
    max_atoms = self.max_atoms
    embedding = tf.eye(np.max(self.atom_cases) + 1)
    atom_numbers_embedded = tf.nn.embedding_lookup(embedding, atom_numbers)
+1 −3
Original line number Diff line number Diff line
@@ -10,9 +10,7 @@ else:
  IS_RELEASE = False

# Environment-specific dependencies.
extras = {
  'jax': ['git+https://github.com/deepmind/dm-haiku', 'git+git://github.com/deepmind/optax.git']
}
extras = {'jax': ['dm-haiku==0.0.3', 'optax==0.0.8']}


# get the version from deepchem/__init__.py