Unverified Commit 4e98fc0c authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2634 from deepchem/jax_layer

First linear layer from AlphaFold
parents 24607818 4257e5eb
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -29,7 +29,6 @@ jobs:
    - name: Build DeepChem
      run: |
        python -m pip install --upgrade pip
        pip install tensorflow'>=2.3,<2.4'
        pip install -e .
    - name: Import checking
      run: python -c "import deepchem"
@@ -142,10 +141,11 @@ jobs:
      if: ${{ (success() || failure()) && (steps.install.outcome == 'failure' || steps.install.outcome == 'success') }}
      shell: bash -l {0}
      run: DGLBACKEND=pytorch pytest -v --ignore-glob='deepchem/**/test*.py' --doctest-modules deepchem
    - name: PyTest
      if: ${{ (success() || failure()) && (steps.install.outcome == 'failure' || steps.install.outcome == 'success') }}
      shell: bash -l {0}
      run: pytest -v -m "not slow and not jax and not torch and not tensorflow" --cov=deepchem --cov-report=xml deepchem
      # These tests are handled by new CI runs
      #- name: PyTest
      #  if: ${{ (success() || failure()) && (steps.install.outcome == 'failure' || steps.install.outcome == 'success') }}
      #  shell: bash -l {0}
      #  run: pytest -v -m "not slow and not jax and not torch and not tensorflow" --cov=deepchem --cov-report=xml deepchem
    - name: Upload coverage to Codecov
      if: ${{ (success() || failure()) && (steps.install.outcome == 'failure' || steps.install.outcome == 'success') }}
      uses: codecov/codecov-action@v1
+2 −2
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Contains class for gaussian process hyperparameter optimizations.
import os
import logging
import tempfile
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, Any

from deepchem.data import Dataset
from deepchem.trans import Transformer
@@ -228,7 +228,7 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    param_keys = list(param_range.keys())

    # Stores all results
    all_results = {}
    all_results: Dict[Any, Any] = {}
    # Store all model references so we don't have to reload
    all_models = {}
    # Stores all model locations
+101 −0
Original line number Diff line number Diff line
import logging

# These layers are optional: jax/haiku are extras, not core deepchem
# dependencies, so importing this module without them must fail loudly
# with an actionable message.
try:
  import jax.numpy as jnp
  import haiku as hk
except ImportError as e:
  # Catch only ImportError — a bare `except:` would also swallow
  # KeyboardInterrupt/SystemExit. Chain the original exception so the
  # underlying cause (which package is missing) stays visible.
  raise ImportError(
      'These classes require Jax and Haiku to be installed.') from e

logger = logging.getLogger(__name__)


class Linear(hk.Module):
  """Protein folding specific Linear Module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs of arbitrary rank
    * Initializers are specified by strings

  This code is adapted from DeepMind's AlphaFold code release
  (https://github.com/deepmind/alphafold).

  Examples
  --------
  >>> import deepchem as dc
  >>> import haiku as hk
  >>> import jax
  >>> import jax.numpy as jnp
  >>> import deepchem.models.jax_models.layers
  >>> def forward_model(x):
  ...   layer = dc.models.jax_models.layers.Linear(2)
  ...   return layer(x)
  >>> f = hk.transform(forward_model)
  >>> rng = jax.random.PRNGKey(42)
  >>> x = jnp.ones([8, 28 * 28])
  >>> params = f.init(rng, x)
  >>> output = f.apply(params, rng, x)
  """

  def __init__(self,
               num_output: int,
               initializer: str = 'linear',
               use_bias: bool = True,
               bias_init: float = 0.,
               name: str = 'linear'):
    """Constructs Linear Module.

    Parameters
    ----------
    num_output: int
      number of output channels.
    initializer: str (default 'linear')
      What initializer to use, should be one of {'linear', 'relu', 'zeros'}
    use_bias: bool (default True)
      Whether to include trainable bias
    bias_init: float (default 0)
      Value used to initialize bias.
    name: str (default 'linear')
      name of module, used for name scopes.

    Raises
    ------
    ValueError
      If `initializer` is not one of 'linear', 'relu' or 'zeros'.
    """

    super().__init__(name=name)
    # Fail fast on a bad initializer string. Without this check, an
    # unknown value only surfaced later as an UnboundLocalError inside
    # __call__, which is much harder to diagnose.
    if initializer not in ('linear', 'relu', 'zeros'):
      raise ValueError(
          "initializer must be one of {'linear', 'relu', 'zeros'}, got %r" %
          (initializer,))
    self.num_output = num_output
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init

  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Connects Module.

    Parameters
    ----------
    inputs: jnp.ndarray
      Tensor of shape [..., num_channel]

    Returns
    -------
    output of shape [..., num_output]
    """
    n_channels = int(inputs.shape[-1])

    weight_shape = [n_channels, self.num_output]
    if self.initializer == 'linear':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
    elif self.initializer == 'relu':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
    elif self.initializer == 'zeros':
      weight_init = hk.initializers.Constant(0.0)
    else:
      # Unreachable when constructed through __init__ (which validates),
      # but guards against the attribute being mutated afterwards.
      raise ValueError(
          "initializer must be one of {'linear', 'relu', 'zeros'}, got %r" %
          (self.initializer,))

    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    # this is equivalent to einsum('...c,cd->...d', inputs, weights)
    # but turns out to be slightly faster
    inputs = jnp.swapaxes(inputs, -1, -2)
    output = jnp.einsum('...cb,cd->...db', inputs, weights)
    output = jnp.swapaxes(output, -1, -2)

    if self.use_bias:
      bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
+29 −0
Original line number Diff line number Diff line
import pytest
import deepchem as dc
import numpy as np

# Optional-dependency guard: jax/haiku are extras. Record availability in
# a flag so tests can consult it; previously the flag was assigned only in
# the except branch, so a successful import left it undefined (NameError).
try:
  import jax
  import jax.numpy as jnp
  from jax import random
  import haiku as hk
  has_haiku_and_optax = True
except ImportError:
  has_haiku_and_optax = False


@pytest.mark.jax
def test_linear():
  """Smoke test for the jax Linear layer: [8, 784] inputs map to [8, 2]."""
  # `dc` is already imported unconditionally at module scope; only the
  # layers submodule needs an explicit import here.
  import haiku as hk
  import deepchem.models.jax_models.layers

  def forward(x):
    # num_output=2: the trailing (channel) dimension is mapped 784 -> 2.
    layer = dc.models.jax_models.layers.Linear(2)
    return layer(x)

  forward = hk.transform(forward)
  rng = jax.random.PRNGKey(42)
  x = jnp.ones([8, 28 * 28])
  params = forward.init(rng, x)
  output = forward.apply(params, rng, x)
  assert output.shape == (8, 2)
+2 −1
Original line number Diff line number Diff line
@@ -64,9 +64,10 @@ class MultiHeadedMATAttention(nn.Module):
  --------
  >>> import deepchem as dc
  >>> from rdkit import Chem
  >>> mol = rdkit.Chem.MolFromSmiles("CC")
  >>> mol = Chem.MolFromSmiles("CC")
  >>> adj_matrix = Chem.GetAdjacencyMatrix(mol)
  >>> distance_matrix = Chem.GetDistanceMatrix(mol)
  >>> import deepchem.models.torch_models.layers
  >>> layer = dc.models.torch_models.layers.MultiHeadedMATAttention(dist_kernel='softmax', lambda_attention=0.33, lambda_distance=0.33, h=2, hsize=2, dropout_p=0.0)
  >>> input_tensor = torch.tensor([[1., 2.], [5., 6.]])
  >>> mask = torch.tensor([[1., 1.], [1., 1.]])
Loading