Unverified Commit 30c51ebb authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1826 from seyonechithrananda/patch-2

Update 22_Transfer_Learning_With_HuggingFace_tox21.ipynb
parents b08cb012 d57fb795
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
%% Cell type:markdown id: tags:

<a href="https://colab.research.google.com/github/seyonechithrananda/bert-loves-chemistry/blob/master/HuggingFace_DeepChem_final_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

%% Cell type:markdown id: tags:

# Tutorial Part 21: Finetuning HuggingFace's RoBERTa for masked language modelling of SMILES
# Tutorial Part 22: Finetuning HuggingFace's RoBERTa for Tox21 Toxicity Predictions on SMILES Strings

![alt text](https://huggingface.co/front/assets/huggingface_mask.svg)

By Seyone Chithrananda

Deep learning for chemistry and materials science remains a novel field with lots of potiential. However, the popularity of transfer learning based methods in areas such as NLP and computer vision have not yet been effectively developed in computational chemistry + machine learning. Using HuggingFace's suite of models and the ByteLevel tokenizer, we are able to train a large-transformer model, RoBERTa, on a large corpus of 100k SMILES strings from a commonly known benchmark chemistry dataset, ZINC.

Training RoBERTa over 5 epochs, the model achieves a pretty good loss of 0.398, and may likely continue to decrease if trained for a larger number of epochs. The model can predict tokens within a SMILES sequence/molecule, allowing for variants of a molecule within discoverable chemical space to be predicted.

By applying the representations of functional groups and atoms learned by the model, we can try to tackle problems of toxicity, solubility, drug-likeness, and synthesis accessibility on smaller datasets using the learned representations as features for graph convolution and attention models on the graph structure of molecules, as well as fine-tuning of BERT. Finally, we propose the use of attention visualization as a helpful tool for chemistry practitioners and students to quickly identify important substructures in various chemical properties.

Additionally, visualization of the attention mechanism have been seen through previous research as incredibly valuable towards chemical reaction classification. The applications of open-sourcing large-scale transformer models such as RoBERTa with HuggingFace may allow for the acceleration of these individual research directions.

A link to a repository which includes the training, uploading and evaluation notebook (with sample predictions on compounds such as Remdesivir) can be found [here](https://github.com/seyonechithrananda/bert-loves-chemistry). All of the notebooks can be copied into a new Colab runtime for easy execution.

For the sake of this tutorial, we'll be fine-tuning RoBERTa on a small-scale molecule dataset, to show the potiential and effectiveness of HuggingFace's NLP-based transfer learning applied to computational chemistry.

%% Cell type:markdown id: tags:

Installing DeepChem from source, alongside RDKit for molecule visualizations

%% Cell type:code id: tags:

``` 
import tensorflow as tf
print("tf.__version__: %s" % str(tf.__version__))
device_name = tf.test.gpu_device_name()
if not device_name:
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
```

%% Output

    tf.__version__: 2.2.0-rc3
    Found GPU at: /device:GPU:0

%% Cell type:code id: tags:

``` 
!pip install transformers
```

%% Output

    Collecting transformers
    [?25l  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
    [K     |████████████████████████████████| 573kB 12.3MB/s
    [?25hCollecting tokenizers==0.5.2
    [?25l  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
    [K     |████████████████████████████████| 3.7MB 47.4MB/s
    [?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.40)
    Collecting sentencepiece
    [?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
    [K     |████████████████████████████████| 1.0MB 27.1MB/s
    [?25hRequirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
    Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
    Collecting sacremoses
    [?25l  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
    [K     |████████████████████████████████| 890kB 55.6MB/s
    [?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)
    Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)
    Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)
    Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
    Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)
    Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)
    Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.40)
    Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)
    Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)
    Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)
    Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)
    Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)
    Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
    Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (2.8.1)
    Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (0.15.2)
    Building wheels for collected packages: sacremoses
      Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
      Created wheel for sacremoses: filename=sacremoses-0.0.41-cp36-none-any.whl size=893334 sha256=49159e7f355e0b67097229550cdd83f0d2fdb44567a06534b35c19bfedadfcc3
      Stored in directory: /root/.cache/pip/wheels/22/5a/d4/b020a81249de7dc63758a34222feaa668dbe8ebfe9170cc9b1
    Successfully built sacremoses
    Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers
    Successfully installed sacremoses-0.0.41 sentencepiece-0.1.85 tokenizers-0.5.2 transformers-2.8.0

%% Cell type:markdown id: tags:

Now, to ensure our model demonstrates an understanding of chemical syntax and molecular structure, we'll be testing it on predicting a masked token/character within the SMILES molecule for Remdesivir.

%% Cell type:code id: tags:

``` 
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline

model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
```

%% Output


    


    


    


    


    


    

%% Cell type:code id: tags:

``` 
remdesivir_mask = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=<mask>1"
remdesivir = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"

"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1"

masked_smi = fill_mask(remdesivir_mask)

for smi in masked_smi:
  print(smi)
```

%% Output

    {'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1</s>', 'score': 0.5986586809158325, 'token': 39}
    {'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1</s>', 'score': 0.09766950458288193, 'token': 51}
    {'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1</s>', 'score': 0.07694468647241592, 'token': 50}
    {'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21</s>', 'score': 0.0241263248026371, 'token': 22}
    {'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1</s>', 'score': 0.0188530795276165, 'token': 44}

%% Cell type:markdown id: tags:

Here, we get some interesting results. The final branch, `C1=CC=CC=C1`, is a  benzene ring. Since its a pretty common molecule, the model is easily able to predict the final double carbon bond with a score of 0.60. Let's get a list of the top 5 predictions (including the target, Remdesivir), and visualize them (with a highlighted focus on the beginning of the final benzene-like pattern). Lets import some various RDKit packages to do so.

%% Cell type:code id: tags:

``` 
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')
```

%% Output

    --2020-04-24 04:23:25--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
    Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.130.3, 104.16.131.3, 2606:4700::6810:8203, ...
    Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.130.3|:443... connected.
    HTTP request sent, awaiting response... 200 OK
    Length: 85055499 (81M) [application/x-sh]
    Saving to: ‘Miniconda3-latest-Linux-x86_64.sh’
    
    Miniconda3-latest-L 100%[===================>]  81.12M  55.6MB/s    in 1.5s
    
    2020-04-24 04:23:27 (55.6 MB/s) - ‘Miniconda3-latest-Linux-x86_64.sh’ saved [85055499/85055499]
    
    PREFIX=/usr/local
    Unpacking payload ...
    Collecting package metadata (current_repodata.json): - \ done
    Solving environment: / - done
    
    ## Package Plan ##
    
      environment location: /usr/local
    
      added / updated specs:
        - _libgcc_mutex==0.1=main
        - asn1crypto==1.3.0=py37_0
        - ca-certificates==2020.1.1=0
        - certifi==2019.11.28=py37_0
        - cffi==1.14.0=py37h2e261b9_0
        - chardet==3.0.4=py37_1003
        - conda-package-handling==1.6.0=py37h7b6447c_0
        - conda==4.8.2=py37_0
        - cryptography==2.8=py37h1ba5d50_0
        - idna==2.8=py37_0
        - ld_impl_linux-64==2.33.1=h53a641e_7
        - libedit==3.1.20181209=hc058e9b_0
        - libffi==3.2.1=hd88cf55_4
        - libgcc-ng==9.1.0=hdf63c60_0
        - libstdcxx-ng==9.1.0=hdf63c60_0
        - ncurses==6.2=he6710b0_0
        - openssl==1.1.1d=h7b6447c_4
        - pip==20.0.2=py37_1
        - pycosat==0.6.3=py37h7b6447c_0
        - pycparser==2.19=py37_0
        - pyopenssl==19.1.0=py37_0
        - pysocks==1.7.1=py37_0
        - python==3.7.6=h0371630_2
        - readline==7.0=h7b6447c_5
        - requests==2.22.0=py37_1
        - ruamel_yaml==0.15.87=py37h7b6447c_0
        - setuptools==45.2.0=py37_0
        - six==1.14.0=py37_0
        - sqlite==3.31.1=h7b6447c_0
        - tk==8.6.8=hbc83047_0
        - tqdm==4.42.1=py_0
        - urllib3==1.25.8=py37_0
        - wheel==0.34.2=py37_0
        - xz==5.2.4=h14c3975_4
        - yaml==0.1.7=had09818_2
        - zlib==1.2.11=h7b6447c_3
    
    
    The following NEW packages will be INSTALLED:
    
      _libgcc_mutex      pkgs/main/linux-64::_libgcc_mutex-0.1-main
      asn1crypto         pkgs/main/linux-64::asn1crypto-1.3.0-py37_0
      ca-certificates    pkgs/main/linux-64::ca-certificates-2020.1.1-0
      certifi            pkgs/main/linux-64::certifi-2019.11.28-py37_0
      cffi               pkgs/main/linux-64::cffi-1.14.0-py37h2e261b9_0
      chardet            pkgs/main/linux-64::chardet-3.0.4-py37_1003
      conda              pkgs/main/linux-64::conda-4.8.2-py37_0
      conda-package-han~ pkgs/main/linux-64::conda-package-handling-1.6.0-py37h7b6447c_0
      cryptography       pkgs/main/linux-64::cryptography-2.8-py37h1ba5d50_0
      idna               pkgs/main/linux-64::idna-2.8-py37_0
      ld_impl_linux-64   pkgs/main/linux-64::ld_impl_linux-64-2.33.1-h53a641e_7
      libedit            pkgs/main/linux-64::libedit-3.1.20181209-hc058e9b_0
      libffi             pkgs/main/linux-64::libffi-3.2.1-hd88cf55_4
      libgcc-ng          pkgs/main/linux-64::libgcc-ng-9.1.0-hdf63c60_0
      libstdcxx-ng       pkgs/main/linux-64::libstdcxx-ng-9.1.0-hdf63c60_0
      ncurses            pkgs/main/linux-64::ncurses-6.2-he6710b0_0
      openssl            pkgs/main/linux-64::openssl-1.1.1d-h7b6447c_4
      pip                pkgs/main/linux-64::pip-20.0.2-py37_1
      pycosat            pkgs/main/linux-64::pycosat-0.6.3-py37h7b6447c_0
      pycparser          pkgs/main/linux-64::pycparser-2.19-py37_0
      pyopenssl          pkgs/main/linux-64::pyopenssl-19.1.0-py37_0
      pysocks            pkgs/main/linux-64::pysocks-1.7.1-py37_0
      python             pkgs/main/linux-64::python-3.7.6-h0371630_2
      readline           pkgs/main/linux-64::readline-7.0-h7b6447c_5
      requests           pkgs/main/linux-64::requests-2.22.0-py37_1
      ruamel_yaml        pkgs/main/linux-64::ruamel_yaml-0.15.87-py37h7b6447c_0
      setuptools         pkgs/main/linux-64::setuptools-45.2.0-py37_0
      six                pkgs/main/linux-64::six-1.14.0-py37_0
      sqlite             pkgs/main/linux-64::sqlite-3.31.1-h7b6447c_0
      tk                 pkgs/main/linux-64::tk-8.6.8-hbc83047_0
      tqdm               pkgs/main/noarch::tqdm-4.42.1-py_0
      urllib3            pkgs/main/linux-64::urllib3-1.25.8-py37_0
      wheel              pkgs/main/linux-64::wheel-0.34.2-py37_0
      xz                 pkgs/main/linux-64::xz-5.2.4-h14c3975_4
      yaml               pkgs/main/linux-64::yaml-0.1.7-had09818_2
      zlib               pkgs/main/linux-64::zlib-1.2.11-h7b6447c_3
    
    
    Preparing transaction: | / - done
    Executing transaction: | / - \ | / - \ | / - \ | / - \ | / done
    installation finished.
    WARNING:
        You currently have a PYTHONPATH environment variable set. This may cause
        unexpected behavior when running the Python interpreter in Miniconda3.
        For best results, please verify that your PYTHONPATH only points to
        directories of packages that are compatible with the Python interpreter
        in Miniconda3: /usr/local
    Collecting package metadata (current_repodata.json): ...working... done
    Solving environment: ...working... done
    
    ## Package Plan ##
    
      environment location: /usr/local
    
      added / updated specs:
        - rdkit
    
    
    The following packages will be downloaded:
    
        package                    |            build
        ---------------------------|-----------------
        boost-1.72.0               |   py37h9de70de_0         316 KB  conda-forge
        boost-cpp-1.72.0           |       h8e57a91_0        21.8 MB  conda-forge
        bzip2-1.0.8                |       h516909a_2         396 KB  conda-forge
        ca-certificates-2020.4.5.1 |       hecc5488_0         146 KB  conda-forge
        cairo-1.16.0               |    hcf35c78_1003         1.5 MB  conda-forge
        certifi-2020.4.5.1         |   py37hc8dfbb8_0         151 KB  conda-forge
        conda-4.8.3                |   py37hc8dfbb8_1         3.0 MB  conda-forge
        fontconfig-2.13.1          |    h86ecdb6_1001         340 KB  conda-forge
        freetype-2.10.1            |       he06d7ca_0         877 KB  conda-forge
        gettext-0.19.8.1           |    hc5be6a0_1002         3.6 MB  conda-forge
        glib-2.64.2                |       h6f030ca_0         3.4 MB  conda-forge
        icu-64.2                   |       he1b5a44_1        12.6 MB  conda-forge
        jpeg-9c                    |    h14c3975_1001         251 KB  conda-forge
        libblas-3.8.0              |      14_openblas          10 KB  conda-forge
        libcblas-3.8.0             |      14_openblas          10 KB  conda-forge
        libgfortran-ng-7.3.0       |       hdf63c60_5         1.7 MB  conda-forge
        libiconv-1.15              |    h516909a_1006         2.0 MB  conda-forge
        liblapack-3.8.0            |      14_openblas          10 KB  conda-forge
        libopenblas-0.3.7          |       h5ec1e0e_6         7.6 MB  conda-forge
        libpng-1.6.37              |       hed695b0_1         308 KB  conda-forge
        libtiff-4.1.0              |       hc7e4089_6         668 KB  conda-forge
        libuuid-2.32.1             |    h14c3975_1000          26 KB  conda-forge
        libwebp-base-1.1.0         |       h516909a_3         845 KB  conda-forge
        libxcb-1.13                |    h14c3975_1002         396 KB  conda-forge
        libxml2-2.9.10             |       hee79883_0         1.3 MB  conda-forge
        lz4-c-1.8.3                |    he1b5a44_1001         187 KB  conda-forge
        numpy-1.18.1               |   py37h8960a57_1         5.2 MB  conda-forge
        olefile-0.46               |             py_0          31 KB  conda-forge
        openssl-1.1.1g             |       h516909a_0         2.1 MB  conda-forge
        pandas-1.0.3               |   py37h0da4684_1        11.1 MB  conda-forge
        pcre-8.44                  |       he1b5a44_0         261 KB  conda-forge
        pillow-7.0.0               |   py37hb39fc2d_0         598 KB
        pixman-0.38.0              |    h516909a_1003         594 KB  conda-forge
        pthread-stubs-0.4          |    h14c3975_1001           5 KB  conda-forge
        pycairo-1.19.1             |   py37h01af8b0_3          77 KB  conda-forge
        python-dateutil-2.8.1      |             py_0         220 KB  conda-forge
        python_abi-3.7             |          1_cp37m           4 KB  conda-forge
        pytz-2019.3                |             py_0         237 KB  conda-forge
        rdkit-2020.03.1            |   py37hdd87690_1        24.7 MB  conda-forge
        xorg-kbproto-1.0.7         |    h14c3975_1002          26 KB  conda-forge
        xorg-libice-1.0.10         |       h516909a_0          57 KB  conda-forge
        xorg-libsm-1.2.3           |    h84519dc_1000          25 KB  conda-forge
        xorg-libx11-1.6.9          |       h516909a_0         918 KB  conda-forge
        xorg-libxau-1.0.9          |       h14c3975_0          13 KB  conda-forge
        xorg-libxdmcp-1.1.3        |       h516909a_0          18 KB  conda-forge
        xorg-libxext-1.3.4         |       h516909a_0          51 KB  conda-forge
        xorg-libxrender-0.9.10     |    h516909a_1002          31 KB  conda-forge
        xorg-renderproto-0.11.1    |    h14c3975_1002           8 KB  conda-forge
        xorg-xextproto-7.3.0       |    h14c3975_1002          27 KB  conda-forge
        xorg-xproto-7.0.31         |    h14c3975_1007          72 KB  conda-forge
        zstd-1.4.4                 |       h3b9ef0a_2         982 KB  conda-forge
        ------------------------------------------------------------
                                               Total:       110.7 MB
    
    The following NEW packages will be INSTALLED:
    
      boost              conda-forge/linux-64::boost-1.72.0-py37h9de70de_0
      boost-cpp          conda-forge/linux-64::boost-cpp-1.72.0-h8e57a91_0
      bzip2              conda-forge/linux-64::bzip2-1.0.8-h516909a_2
      cairo              conda-forge/linux-64::cairo-1.16.0-hcf35c78_1003
      fontconfig         conda-forge/linux-64::fontconfig-2.13.1-h86ecdb6_1001
      freetype           conda-forge/linux-64::freetype-2.10.1-he06d7ca_0
      gettext            conda-forge/linux-64::gettext-0.19.8.1-hc5be6a0_1002
      glib               conda-forge/linux-64::glib-2.64.2-h6f030ca_0
      icu                conda-forge/linux-64::icu-64.2-he1b5a44_1
      jpeg               conda-forge/linux-64::jpeg-9c-h14c3975_1001
      libblas            conda-forge/linux-64::libblas-3.8.0-14_openblas
      libcblas           conda-forge/linux-64::libcblas-3.8.0-14_openblas
      libgfortran-ng     conda-forge/linux-64::libgfortran-ng-7.3.0-hdf63c60_5
      libiconv           conda-forge/linux-64::libiconv-1.15-h516909a_1006
      liblapack          conda-forge/linux-64::liblapack-3.8.0-14_openblas
      libopenblas        conda-forge/linux-64::libopenblas-0.3.7-h5ec1e0e_6
      libpng             conda-forge/linux-64::libpng-1.6.37-hed695b0_1
      libtiff            conda-forge/linux-64::libtiff-4.1.0-hc7e4089_6
      libuuid            conda-forge/linux-64::libuuid-2.32.1-h14c3975_1000
      libwebp-base       conda-forge/linux-64::libwebp-base-1.1.0-h516909a_3
      libxcb             conda-forge/linux-64::libxcb-1.13-h14c3975_1002
      libxml2            conda-forge/linux-64::libxml2-2.9.10-hee79883_0
      lz4-c              conda-forge/linux-64::lz4-c-1.8.3-he1b5a44_1001
      numpy              conda-forge/linux-64::numpy-1.18.1-py37h8960a57_1
      olefile            conda-forge/noarch::olefile-0.46-py_0
      pandas             conda-forge/linux-64::pandas-1.0.3-py37h0da4684_1
      pcre               conda-forge/linux-64::pcre-8.44-he1b5a44_0
      pillow             pkgs/main/linux-64::pillow-7.0.0-py37hb39fc2d_0
      pixman             conda-forge/linux-64::pixman-0.38.0-h516909a_1003
      pthread-stubs      conda-forge/linux-64::pthread-stubs-0.4-h14c3975_1001
      pycairo            conda-forge/linux-64::pycairo-1.19.1-py37h01af8b0_3
      python-dateutil    conda-forge/noarch::python-dateutil-2.8.1-py_0
      python_abi         conda-forge/linux-64::python_abi-3.7-1_cp37m
      pytz               conda-forge/noarch::pytz-2019.3-py_0
      rdkit              conda-forge/linux-64::rdkit-2020.03.1-py37hdd87690_1
      xorg-kbproto       conda-forge/linux-64::xorg-kbproto-1.0.7-h14c3975_1002
      xorg-libice        conda-forge/linux-64::xorg-libice-1.0.10-h516909a_0
      xorg-libsm         conda-forge/linux-64::xorg-libsm-1.2.3-h84519dc_1000
      xorg-libx11        conda-forge/linux-64::xorg-libx11-1.6.9-h516909a_0
      xorg-libxau        conda-forge/linux-64::xorg-libxau-1.0.9-h14c3975_0
      xorg-libxdmcp      conda-forge/linux-64::xorg-libxdmcp-1.1.3-h516909a_0
      xorg-libxext       conda-forge/linux-64::xorg-libxext-1.3.4-h516909a_0
      xorg-libxrender    conda-forge/linux-64::xorg-libxrender-0.9.10-h516909a_1002
      xorg-renderproto   conda-forge/linux-64::xorg-renderproto-0.11.1-h14c3975_1002
      xorg-xextproto     conda-forge/linux-64::xorg-xextproto-7.3.0-h14c3975_1002
      xorg-xproto        conda-forge/linux-64::xorg-xproto-7.0.31-h14c3975_1007
      zstd               conda-forge/linux-64::zstd-1.4.4-h3b9ef0a_2
    
    The following packages will be UPDATED:
    
      ca-certificates     pkgs/main::ca-certificates-2020.1.1-0 --> conda-forge::ca-certificates-2020.4.5.1-hecc5488_0
      certifi              pkgs/main::certifi-2019.11.28-py37_0 --> conda-forge::certifi-2020.4.5.1-py37hc8dfbb8_0
      conda                       pkgs/main::conda-4.8.2-py37_0 --> conda-forge::conda-4.8.3-py37hc8dfbb8_1
      openssl              pkgs/main::openssl-1.1.1d-h7b6447c_4 --> conda-forge::openssl-1.1.1g-h516909a_0
    
    
    Preparing transaction: ...working... done
    Verifying transaction: ...working... done
    Executing transaction: ...working... done
    
    real	0m37.898s
    user	0m31.771s
    sys	0m3.614s

%% Cell type:code id: tags:

``` 
import torch
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import MolToImage
from PIL import Image


def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol


def find_matches_one(mol,submol):
    #find all matching atoms for each submol in submol_list in mol.
    match_dict = {}
    mols = [mol,submol] #pairwise search
    res=rdFMCS.FindMCS(mols) #,ringMatchesRingOnly=True)
    mcsp = Chem.MolFromSmarts(res.smartsString)
    matches = mol.GetSubstructMatches(mcsp)
    return matches

#Draw the molecule
def get_image(mol,atomset):
    hcolor = colors.to_rgb('green')
    if atomset is not None:
        #highlight the atoms set while drawing the whole molecule.
        img = MolToImage(mol, size=(600, 600),fitImage=True, highlightAtoms=atomset,highlightColor=hcolor)
    else:
        img = MolToImage(mol, size=(400, 400),fitImage=True)
    return img
```

%% Cell type:code id: tags:

``` 
sequence = f"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC={tokenizer.mask_token}1"
substructure = "CC=CC"
image_list = []

input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
  smi = (sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
  print (smi)
  smi_mol = get_mol(smi)
  substructure_mol = get_mol(substructure)
  if smi_mol is None: # if the model's token prediction isn't chemically feasible
    continue
  Draw.MolToFile(smi_mol, smi+".png")
  matches = find_matches_one(smi_mol, substructure_mol)
  atomset = list(matches[0])
  img = get_image(smi_mol, atomset)
  img.format="PNG"
  image_list.append(img)
```

%% Output

    CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1
    CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1
    CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1
    CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21
    CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1

%% Cell type:code id: tags:

``` 
from IPython.display import Image

for img in image_list:
  display(img)
```

%% Output



%% Cell type:markdown id: tags:

As we can see above, 2 of 4 of the model's MLM predictions are chemically valid. The one the model would've chosen (with a score of 0.6), is the first image, in which the top left molecular structure resembles the benzene found in the therapy Remdesivir. Overall, the model seems to understand syntax with a pretty decent degree of certainity.

However, further training on a more specific dataset (say leads for a specific target) may generate a stronger MLM model. Let's now fine-tune our model on a dataset of our choice, Tox21.

%% Cell type:markdown id: tags:

# Fine-tuning ChemBERTa on a Small Mollecular Dataset

Tumor suppressor protein (SR.p53), typically the p53 pathway is “off” and is activated when cells are under stress or damaged, hence being a good indicator of DNA damage and other cellular stresses. Tumor suppressor protein p53 is activated by inducing DNA repair, cell cycle arrest and apoptosis.

The Tox21 challenge was introduced in 2014 in an attempt to build models that are successful in predicting compounds' interference in biochemical pathways using only chemical structure data. The computational models produced from the challenge could become decision-making tools for government agencies in determining which environmental chemicals and drugs are of the greatest potential concern to human health. Additionally, these models can act as drug screening tools in the drug discovery pipelines for toxicity.

%% Cell type:markdown id: tags:

Lets start by loading the dataset from s3, before importing apex and transformers, the tool which will allow us to import the pre-trained masked-language modelling architecture trained on ZINC15.

%% Cell type:code id: tags:

``` 
!wget https://t.co/zrC7F8DcRs?amp=1
```

%% Output

    --2020-04-24 04:31:11--  https://t.co/zrC7F8DcRs?amp=1
    Resolving t.co (t.co)... 104.244.42.5, 104.244.42.133, 104.244.42.69, ...
    Connecting to t.co (t.co)|104.244.42.5|:443... connected.
    HTTP request sent, awaiting response... 301 Moved Permanently
    Location: https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv [following]
    --2020-04-24 04:31:12--  https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv
    Resolving deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)... 52.219.112.193
    Connecting to deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)|52.219.112.193|:443... connected.
    HTTP request sent, awaiting response... 200 OK
    Length: 85962 (84K) [text/csv]
    Saving to: ‘zrC7F8DcRs?amp=1’
    
    
zrC7F8DcRs?amp=1      0%[                    ]       0  --.-KB/s               
zrC7F8DcRs?amp=1     69%[============>       ]  58.65K   162KB/s               
zrC7F8DcRs?amp=1    100%[===================>]  83.95K   231KB/s    in 0.4s
    
    2020-04-24 04:31:13 (231 KB/s) - ‘zrC7F8DcRs?amp=1’ saved [85962/85962]
    

%% Cell type:markdown id: tags:

We want to install NVIDIA's Apex tool, for the training pipeline used by `simple-transformer` and Weights and Biases.

%% Cell type:code id: tags:

``` 
!git clone https://github.com/NVIDIA/apex
!cd /content/apex
!pip install -v --no-cache-dir /content/apex
!cd ..
```

%% Output

    Cloning into 'apex'...
    remote: Enumerating objects: 4, done.[K
    remote: Counting objects: 100% (4/4), done.[K
    remote: Compressing objects: 100% (4/4), done.[K
    remote: Total 6593 (delta 0), reused 0 (delta 0), pack-reused 6589[K
    Receiving objects: 100% (6593/6593), 13.70 MiB | 1.52 MiB/s, done.
    Resolving deltas: 100% (4383/4383), done.
    Created temporary directory: /tmp/pip-ephem-wheel-cache-q5nbg4uh
    Created temporary directory: /tmp/pip-req-tracker-ixo7f527
    Created requirements tracker '/tmp/pip-req-tracker-ixo7f527'
    Created temporary directory: /tmp/pip-install-4bvkod3b
    Processing ./apex
      Created temporary directory: /tmp/pip-req-build-nltce4xy
      Added file:///content/apex to build tracker '/tmp/pip-req-tracker-ixo7f527'
        Running setup.py (path:/tmp/pip-req-build-nltce4xy/setup.py) egg_info for package from file:///content/apex
        Running command python setup.py egg_info
        torch.__version__  =  1.4.0
        running egg_info
        creating /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info
        writing /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/PKG-INFO
        writing dependency_links to /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/dependency_links.txt
        writing top-level names to /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/top_level.txt
        writing manifest file '/tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/SOURCES.txt'
        writing manifest file '/tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/SOURCES.txt'
        /tmp/pip-req-build-nltce4xy/setup.py:46: UserWarning: Option --pyprof not specified. Not installing PyProf dependencies!
          warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
      Source in /tmp/pip-req-build-nltce4xy has version 0.1, which satisfies requirement apex==0.1 from file:///content/apex
      Removed apex==0.1 from file:///content/apex from build tracker '/tmp/pip-req-tracker-ixo7f527'
    Building wheels for collected packages: apex
      Created temporary directory: /tmp/pip-wheel-rym8vea2
      Building wheel for apex (setup.py) ... [?25l  Destination directory: /tmp/pip-wheel-rym8vea2
      Running command /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-nltce4xy/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-nltce4xy/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-rym8vea2 --python-tag cp36
      torch.__version__  =  1.4.0
      /tmp/pip-req-build-nltce4xy/setup.py:46: UserWarning: Option --pyprof not specified. Not installing PyProf dependencies!
        warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
      running bdist_wheel
      running build
      running build_py
      creating build
      creating build/lib
      creating build/lib/apex
      copying apex/__init__.py -> build/lib/apex
      creating build/lib/apex/optimizers
      copying apex/optimizers/__init__.py -> build/lib/apex/optimizers
      copying apex/optimizers/fused_lamb.py -> build/lib/apex/optimizers
      copying apex/optimizers/fused_sgd.py -> build/lib/apex/optimizers
      copying apex/optimizers/fused_novograd.py -> build/lib/apex/optimizers
      copying apex/optimizers/fused_adam.py -> build/lib/apex/optimizers
      creating build/lib/apex/pyprof
      copying apex/pyprof/__init__.py -> build/lib/apex/pyprof
      creating build/lib/apex/normalization
      copying apex/normalization/__init__.py -> build/lib/apex/normalization
      copying apex/normalization/fused_layer_norm.py -> build/lib/apex/normalization
      creating build/lib/apex/multi_tensor_apply
      copying apex/multi_tensor_apply/__init__.py -> build/lib/apex/multi_tensor_apply
      copying apex/multi_tensor_apply/multi_tensor_apply.py -> build/lib/apex/multi_tensor_apply
      creating build/lib/apex/parallel
      copying apex/parallel/optimized_sync_batchnorm_kernel.py -> build/lib/apex/parallel
      copying apex/parallel/optimized_sync_batchnorm.py -> build/lib/apex/parallel
      copying apex/parallel/__init__.py -> build/lib/apex/parallel
      copying apex/parallel/LARC.py -> build/lib/apex/parallel
      copying apex/parallel/sync_batchnorm_kernel.py -> build/lib/apex/parallel
      copying apex/parallel/distributed.py -> build/lib/apex/parallel
      copying apex/parallel/sync_batchnorm.py -> build/lib/apex/parallel
      copying apex/parallel/multiproc.py -> build/lib/apex/parallel
      creating build/lib/apex/fp16_utils
      copying apex/fp16_utils/__init__.py -> build/lib/apex/fp16_utils
      copying apex/fp16_utils/fp16util.py -> build/lib/apex/fp16_utils
      copying apex/fp16_utils/fp16_optimizer.py -> build/lib/apex/fp16_utils
      copying apex/fp16_utils/loss_scaler.py -> build/lib/apex/fp16_utils
      creating build/lib/apex/reparameterization
      copying apex/reparameterization/__init__.py -> build/lib/apex/reparameterization
      copying apex/reparameterization/reparameterization.py -> build/lib/apex/reparameterization
      copying apex/reparameterization/weight_norm.py -> build/lib/apex/reparameterization
      creating build/lib/apex/contrib
      copying apex/contrib/__init__.py -> build/lib/apex/contrib
      creating build/lib/apex/mlp
      copying apex/mlp/__init__.py -> build/lib/apex/mlp
      copying apex/mlp/mlp.py -> build/lib/apex/mlp
      creating build/lib/apex/RNN
      copying apex/RNN/cells.py -> build/lib/apex/RNN
      copying apex/RNN/__init__.py -> build/lib/apex/RNN
      copying apex/RNN/RNNBackend.py -> build/lib/apex/RNN
      copying apex/RNN/models.py -> build/lib/apex/RNN
      creating build/lib/apex/amp
      copying apex/amp/opt.py -> build/lib/apex/amp
      copying apex/amp/__init__.py -> build/lib/apex/amp
      copying apex/amp/scaler.py -> build/lib/apex/amp
      copying apex/amp/__version__.py -> build/lib/apex/amp
      copying apex/amp/_amp_state.py -> build/lib/apex/amp
      copying apex/amp/rnn_compat.py -> build/lib/apex/amp
      copying apex/amp/utils.py -> build/lib/apex/amp
      copying apex/amp/wrap.py -> build/lib/apex/amp
      copying apex/amp/_initialize.py -> build/lib/apex/amp
      copying apex/amp/frontend.py -> build/lib/apex/amp
      copying apex/amp/_process_optimizer.py -> build/lib/apex/amp
      copying apex/amp/compat.py -> build/lib/apex/amp
      copying apex/amp/amp.py -> build/lib/apex/amp
      copying apex/amp/handle.py -> build/lib/apex/amp
      creating build/lib/apex/pyprof/nvtx
      copying apex/pyprof/nvtx/__init__.py -> build/lib/apex/pyprof/nvtx
      copying apex/pyprof/nvtx/nvmarker.py -> build/lib/apex/pyprof/nvtx
      creating build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/__main__.py -> build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/__init__.py -> build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/kernel.py -> build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/db.py -> build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/nvvp.py -> build/lib/apex/pyprof/parse
      copying apex/pyprof/parse/parse.py -> build/lib/apex/pyprof/parse
      creating build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/utility.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/prof.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/__main__.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/__init__.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/misc.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/conv.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/embedding.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/output.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/usage.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/index_slice_join_mutate.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/convert.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/pooling.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/linear.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/pointwise.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/activation.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/base.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/dropout.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/optim.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/data.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/reduction.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/blas.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/loss.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/softmax.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/randomSample.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/normalization.py -> build/lib/apex/pyprof/prof
      copying apex/pyprof/prof/recurrentCell.py -> build/lib/apex/pyprof/prof
      creating build/lib/apex/contrib/optimizers
      copying apex/contrib/optimizers/__init__.py -> build/lib/apex/contrib/optimizers
      copying apex/contrib/optimizers/fused_lamb.py -> build/lib/apex/contrib/optimizers
      copying apex/contrib/optimizers/fused_sgd.py -> build/lib/apex/contrib/optimizers
      copying apex/contrib/optimizers/fp16_optimizer.py -> build/lib/apex/contrib/optimizers
      copying apex/contrib/optimizers/fused_adam.py -> build/lib/apex/contrib/optimizers
      creating build/lib/apex/contrib/groupbn
      copying apex/contrib/groupbn/__init__.py -> build/lib/apex/contrib/groupbn
      copying apex/contrib/groupbn/batch_norm.py -> build/lib/apex/contrib/groupbn
      creating build/lib/apex/contrib/xentropy
      copying apex/contrib/xentropy/__init__.py -> build/lib/apex/contrib/xentropy
      copying apex/contrib/xentropy/softmax_xentropy.py -> build/lib/apex/contrib/xentropy
      creating build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/__init__.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/encdec_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/encdec_multihead_attn.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/fast_self_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/self_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/self_multihead_attn.py -> build/lib/apex/contrib/multihead_attn
      copying apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py -> build/lib/apex/contrib/multihead_attn
      creating build/lib/apex/amp/lists
      copying apex/amp/lists/__init__.py -> build/lib/apex/amp/lists
      copying apex/amp/lists/functional_overrides.py -> build/lib/apex/amp/lists
      copying apex/amp/lists/tensor_overrides.py -> build/lib/apex/amp/lists
      copying apex/amp/lists/torch_overrides.py -> build/lib/apex/amp/lists
      installing to build/bdist.linux-x86_64/wheel
      running install
      running install_lib
      creating build/bdist.linux-x86_64
      creating build/bdist.linux-x86_64/wheel
      creating build/bdist.linux-x86_64/wheel/apex
      creating build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/optimizers/__init__.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/optimizers/fused_lamb.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/optimizers/fused_sgd.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/optimizers/fused_novograd.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/optimizers/fused_adam.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
      copying build/lib/apex/__init__.py -> build/bdist.linux-x86_64/wheel/apex
      creating build/bdist.linux-x86_64/wheel/apex/pyprof
      creating build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
      copying build/lib/apex/pyprof/nvtx/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
      copying build/lib/apex/pyprof/nvtx/nvmarker.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
      copying build/lib/apex/pyprof/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof
      creating build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/__main__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/kernel.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/db.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/nvvp.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      copying build/lib/apex/pyprof/parse/parse.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
      creating build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/utility.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/prof.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/__main__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/misc.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/conv.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/embedding.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/output.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/usage.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/index_slice_join_mutate.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/convert.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/pooling.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/linear.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/pointwise.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/activation.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/base.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/dropout.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/optim.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/data.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/reduction.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/blas.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/loss.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/softmax.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/randomSample.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/normalization.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      copying build/lib/apex/pyprof/prof/recurrentCell.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
      creating build/bdist.linux-x86_64/wheel/apex/normalization
      copying build/lib/apex/normalization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/normalization
      copying build/lib/apex/normalization/fused_layer_norm.py -> build/bdist.linux-x86_64/wheel/apex/normalization
      creating build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
      copying build/lib/apex/multi_tensor_apply/__init__.py -> build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
      copying build/lib/apex/multi_tensor_apply/multi_tensor_apply.py -> build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
      creating build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/optimized_sync_batchnorm_kernel.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/optimized_sync_batchnorm.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/__init__.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/LARC.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/sync_batchnorm_kernel.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/distributed.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/sync_batchnorm.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      copying build/lib/apex/parallel/multiproc.py -> build/bdist.linux-x86_64/wheel/apex/parallel
      creating build/bdist.linux-x86_64/wheel/apex/fp16_utils
      copying build/lib/apex/fp16_utils/__init__.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
      copying build/lib/apex/fp16_utils/fp16util.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
      copying build/lib/apex/fp16_utils/fp16_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
      copying build/lib/apex/fp16_utils/loss_scaler.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
      creating build/bdist.linux-x86_64/wheel/apex/reparameterization
      copying build/lib/apex/reparameterization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
      copying build/lib/apex/reparameterization/reparameterization.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
      copying build/lib/apex/reparameterization/weight_norm.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
      creating build/bdist.linux-x86_64/wheel/apex/contrib
      creating build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/optimizers/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/optimizers/fused_lamb.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/optimizers/fused_sgd.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/optimizers/fp16_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/optimizers/fused_adam.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
      copying build/lib/apex/contrib/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib
      creating build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
      copying build/lib/apex/contrib/groupbn/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
      copying build/lib/apex/contrib/groupbn/batch_norm.py -> build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
      creating build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
      copying build/lib/apex/contrib/xentropy/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
      copying build/lib/apex/contrib/xentropy/softmax_xentropy.py -> build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
      creating build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/encdec_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/encdec_multihead_attn.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/self_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/self_multihead_attn.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      copying build/lib/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
      creating build/bdist.linux-x86_64/wheel/apex/mlp
      copying build/lib/apex/mlp/__init__.py -> build/bdist.linux-x86_64/wheel/apex/mlp
      copying build/lib/apex/mlp/mlp.py -> build/bdist.linux-x86_64/wheel/apex/mlp
      creating build/bdist.linux-x86_64/wheel/apex/RNN
      copying build/lib/apex/RNN/cells.py -> build/bdist.linux-x86_64/wheel/apex/RNN
      copying build/lib/apex/RNN/__init__.py -> build/bdist.linux-x86_64/wheel/apex/RNN
      copying build/lib/apex/RNN/RNNBackend.py -> build/bdist.linux-x86_64/wheel/apex/RNN
      copying build/lib/apex/RNN/models.py -> build/bdist.linux-x86_64/wheel/apex/RNN
      creating build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/opt.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/__init__.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/scaler.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/__version__.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/_amp_state.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/rnn_compat.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/utils.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/wrap.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/_initialize.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/frontend.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/_process_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/compat.py -> build/bdist.linux-x86_64/wheel/apex/amp
      creating build/bdist.linux-x86_64/wheel/apex/amp/lists
      copying build/lib/apex/amp/lists/__init__.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
      copying build/lib/apex/amp/lists/functional_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
      copying build/lib/apex/amp/lists/tensor_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
      copying build/lib/apex/amp/lists/torch_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
      copying build/lib/apex/amp/amp.py -> build/bdist.linux-x86_64/wheel/apex/amp
      copying build/lib/apex/amp/handle.py -> build/bdist.linux-x86_64/wheel/apex/amp
      running install_egg_info
      running egg_info
      creating apex.egg-info
      writing apex.egg-info/PKG-INFO
      writing dependency_links to apex.egg-info/dependency_links.txt
      writing top-level names to apex.egg-info/top_level.txt
      writing manifest file 'apex.egg-info/SOURCES.txt'
      writing manifest file 'apex.egg-info/SOURCES.txt'
      Copying apex.egg-info to build/bdist.linux-x86_64/wheel/apex-0.1-py3.6.egg-info
      running install_scripts
      adding license file "LICENSE" (matched pattern "LICEN[CS]E*")
      creating build/bdist.linux-x86_64/wheel/apex-0.1.dist-info/WHEEL
      creating '/tmp/pip-wheel-rym8vea2/apex-0.1-cp36-none-any.whl' and adding 'build/bdist.linux-x86_64/wheel' to it
      adding 'apex/__init__.py'
      adding 'apex/RNN/RNNBackend.py'
      adding 'apex/RNN/__init__.py'
      adding 'apex/RNN/cells.py'
      adding 'apex/RNN/models.py'
      adding 'apex/amp/__init__.py'
      adding 'apex/amp/__version__.py'
      adding 'apex/amp/_amp_state.py'
      adding 'apex/amp/_initialize.py'
      adding 'apex/amp/_process_optimizer.py'
      adding 'apex/amp/amp.py'
      adding 'apex/amp/compat.py'
      adding 'apex/amp/frontend.py'
      adding 'apex/amp/handle.py'
      adding 'apex/amp/opt.py'
      adding 'apex/amp/rnn_compat.py'
      adding 'apex/amp/scaler.py'
      adding 'apex/amp/utils.py'
      adding 'apex/amp/wrap.py'
      adding 'apex/amp/lists/__init__.py'
      adding 'apex/amp/lists/functional_overrides.py'
      adding 'apex/amp/lists/tensor_overrides.py'
      adding 'apex/amp/lists/torch_overrides.py'
      adding 'apex/contrib/__init__.py'
      adding 'apex/contrib/groupbn/__init__.py'
      adding 'apex/contrib/groupbn/batch_norm.py'
      adding 'apex/contrib/multihead_attn/__init__.py'
      adding 'apex/contrib/multihead_attn/encdec_multihead_attn.py'
      adding 'apex/contrib/multihead_attn/encdec_multihead_attn_func.py'
      adding 'apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py'
      adding 'apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py'
      adding 'apex/contrib/multihead_attn/fast_self_multihead_attn_func.py'
      adding 'apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py'
      adding 'apex/contrib/multihead_attn/self_multihead_attn.py'
      adding 'apex/contrib/multihead_attn/self_multihead_attn_func.py'
      adding 'apex/contrib/optimizers/__init__.py'
      adding 'apex/contrib/optimizers/fp16_optimizer.py'
      adding 'apex/contrib/optimizers/fused_adam.py'
      adding 'apex/contrib/optimizers/fused_lamb.py'
      adding 'apex/contrib/optimizers/fused_sgd.py'
      adding 'apex/contrib/xentropy/__init__.py'
      adding 'apex/contrib/xentropy/softmax_xentropy.py'
      adding 'apex/fp16_utils/__init__.py'
      adding 'apex/fp16_utils/fp16_optimizer.py'
      adding 'apex/fp16_utils/fp16util.py'
      adding 'apex/fp16_utils/loss_scaler.py'
      adding 'apex/mlp/__init__.py'
      adding 'apex/mlp/mlp.py'
      adding 'apex/multi_tensor_apply/__init__.py'
      adding 'apex/multi_tensor_apply/multi_tensor_apply.py'
      adding 'apex/normalization/__init__.py'
      adding 'apex/normalization/fused_layer_norm.py'
      adding 'apex/optimizers/__init__.py'
      adding 'apex/optimizers/fused_adam.py'
      adding 'apex/optimizers/fused_lamb.py'
      adding 'apex/optimizers/fused_novograd.py'
      adding 'apex/optimizers/fused_sgd.py'
      adding 'apex/parallel/LARC.py'
      adding 'apex/parallel/__init__.py'
      adding 'apex/parallel/distributed.py'
      adding 'apex/parallel/multiproc.py'
      adding 'apex/parallel/optimized_sync_batchnorm.py'
      adding 'apex/parallel/optimized_sync_batchnorm_kernel.py'
      adding 'apex/parallel/sync_batchnorm.py'
      adding 'apex/parallel/sync_batchnorm_kernel.py'
      adding 'apex/pyprof/__init__.py'
      adding 'apex/pyprof/nvtx/__init__.py'
      adding 'apex/pyprof/nvtx/nvmarker.py'
      adding 'apex/pyprof/parse/__init__.py'
      adding 'apex/pyprof/parse/__main__.py'
      adding 'apex/pyprof/parse/db.py'
      adding 'apex/pyprof/parse/kernel.py'
      adding 'apex/pyprof/parse/nvvp.py'
      adding 'apex/pyprof/parse/parse.py'
      adding 'apex/pyprof/prof/__init__.py'
      adding 'apex/pyprof/prof/__main__.py'
      adding 'apex/pyprof/prof/activation.py'
      adding 'apex/pyprof/prof/base.py'
      adding 'apex/pyprof/prof/blas.py'
      adding 'apex/pyprof/prof/conv.py'
      adding 'apex/pyprof/prof/convert.py'
      adding 'apex/pyprof/prof/data.py'
      adding 'apex/pyprof/prof/dropout.py'
      adding 'apex/pyprof/prof/embedding.py'
      adding 'apex/pyprof/prof/index_slice_join_mutate.py'
      adding 'apex/pyprof/prof/linear.py'
      adding 'apex/pyprof/prof/loss.py'
      adding 'apex/pyprof/prof/misc.py'
      adding 'apex/pyprof/prof/normalization.py'
      adding 'apex/pyprof/prof/optim.py'
      adding 'apex/pyprof/prof/output.py'
      adding 'apex/pyprof/prof/pointwise.py'
      adding 'apex/pyprof/prof/pooling.py'
      adding 'apex/pyprof/prof/prof.py'
      adding 'apex/pyprof/prof/randomSample.py'
      adding 'apex/pyprof/prof/recurrentCell.py'
      adding 'apex/pyprof/prof/reduction.py'
      adding 'apex/pyprof/prof/softmax.py'
      adding 'apex/pyprof/prof/usage.py'
      adding 'apex/pyprof/prof/utility.py'
      adding 'apex/reparameterization/__init__.py'
      adding 'apex/reparameterization/reparameterization.py'
      adding 'apex/reparameterization/weight_norm.py'
      adding 'apex-0.1.dist-info/LICENSE'
      adding 'apex-0.1.dist-info/METADATA'
      adding 'apex-0.1.dist-info/WHEEL'
      adding 'apex-0.1.dist-info/top_level.txt'
      adding 'apex-0.1.dist-info/RECORD'
      removing build/bdist.linux-x86_64/wheel
    [?25hdone
      Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=157194 sha256=9d3ee1058b54c45a2be01bb279de4ec903855ec8cf6d6f3d5559dffda5dfaf89
      Stored in directory: /tmp/pip-ephem-wheel-cache-q5nbg4uh/wheels/b1/3a/aa/d84906eaab780ae580c7a5686a33bf2820d8590ac3b60d5967
      Removing source in /tmp/pip-req-build-nltce4xy
    Successfully built apex
    Installing collected packages: apex
    
    Successfully installed apex-0.1
    Cleaning up...
    Removed build tracker '/tmp/pip-req-tracker-ixo7f527'

%% Cell type:code id: tags:

``` 
# Test if NVIDIA apex training tool works
from apex import amp
```

%% Cell type:markdown id: tags:

If you're only running the toxicity prediction portion of this tutorial, make sure you install transformers here. If you've ran all the cells before, you can ignore this install as we've already done `pip install transformers` before.

%% Cell type:code id: tags:

``` 
!pip install transformers
```

%% Output

    Collecting transformers
    [?25l  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
    
[K     |▋                               | 10kB 23.7MB/s eta 0:00:01
[K     |█▏                              | 20kB 30.6MB/s eta 0:00:01
[K     |█▊                              | 30kB 34.7MB/s eta 0:00:01
[K     |██▎                             | 40kB 38.5MB/s eta 0:00:01
[K     |███                             | 51kB 39.1MB/s eta 0:00:01
[K     |███▌                            | 61kB 41.5MB/s eta 0:00:01
[K     |████                            | 71kB 32.6MB/s eta 0:00:01
[K     |████▋                           | 81kB 34.1MB/s eta 0:00:01
[K     |█████▎                          | 92kB 35.4MB/s eta 0:00:01
[K     |█████▉                          | 102kB 33.4MB/s eta 0:00:01
[K     |██████▍                         | 112kB 33.4MB/s eta 0:00:01
[K     |███████                         | 122kB 33.4MB/s eta 0:00:01
[K     |███████▋                        | 133kB 33.4MB/s eta 0:00:01
[K     |████████▏                       | 143kB 33.4MB/s eta 0:00:01
[K     |████████▊                       | 153kB 33.4MB/s eta 0:00:01
[K     |█████████▎                      | 163kB 33.4MB/s eta 0:00:01
[K     |█████████▉                      | 174kB 33.4MB/s eta 0:00:01
[K     |██████████▌                     | 184kB 33.4MB/s eta 0:00:01
[K     |███████████                     | 194kB 33.4MB/s eta 0:00:01
[K     |███████████▋                    | 204kB 33.4MB/s eta 0:00:01
[K     |████████████▏                   | 215kB 33.4MB/s eta 0:00:01
[K     |████████████▉                   | 225kB 33.4MB/s eta 0:00:01
[K     |█████████████▍                  | 235kB 33.4MB/s eta 0:00:01
[K     |██████████████                  | 245kB 33.4MB/s eta 0:00:01
[K     |██████████████▌                 | 256kB 33.4MB/s eta 0:00:01
[K     |███████████████▏                | 266kB 33.4MB/s eta 0:00:01
[K     |███████████████▊                | 276kB 33.4MB/s eta 0:00:01
[K     |████████████████▎               | 286kB 33.4MB/s eta 0:00:01
[K     |████████████████▉               | 296kB 33.4MB/s eta 0:00:01
[K     |█████████████████▍              | 307kB 33.4MB/s eta 0:00:01
[K     |██████████████████              | 317kB 33.4MB/s eta 0:00:01
[K     |██████████████████▋             | 327kB 33.4MB/s eta 0:00:01
[K     |███████████████████▏            | 337kB 33.4MB/s eta 0:00:01
[K     |███████████████████▊            | 348kB 33.4MB/s eta 0:00:01
[K     |████████████████████▍           | 358kB 33.4MB/s eta 0:00:01
[K     |█████████████████████           | 368kB 33.4MB/s eta 0:00:01
[K     |█████████████████████▌          | 378kB 33.4MB/s eta 0:00:01
[K     |██████████████████████          | 389kB 33.4MB/s eta 0:00:01
[K     |██████████████████████▊         | 399kB 33.4MB/s eta 0:00:01
[K     |███████████████████████▎        | 409kB 33.4MB/s eta 0:00:01
[K     |███████████████████████▉        | 419kB 33.4MB/s eta 0:00:01
[K     |████████████████████████▍       | 430kB 33.4MB/s eta 0:00:01
[K     |█████████████████████████       | 440kB 33.4MB/s eta 0:00:01
[K     |█████████████████████████▋      | 450kB 33.4MB/s eta 0:00:01
[K     |██████████████████████████▏     | 460kB 33.4MB/s eta 0:00:01
[K     |██████████████████████████▊     | 471kB 33.4MB/s eta 0:00:01
[K     |███████████████████████████▎    | 481kB 33.4MB/s eta 0:00:01
[K     |████████████████████████████    | 491kB 33.4MB/s eta 0:00:01
[K     |████████████████████████████▌   | 501kB 33.4MB/s eta 0:00:01
[K     |█████████████████████████████   | 512kB 33.4MB/s eta 0:00:01
[K     |█████████████████████████████▋  | 522kB 33.4MB/s eta 0:00:01
[K     |██████████████████████████████▎ | 532kB 33.4MB/s eta 0:00:01
[K     |██████████████████████████████▉ | 542kB 33.4MB/s eta 0:00:01
[K     |███████████████████████████████▍| 552kB 33.4MB/s eta 0:00:01
[K     |████████████████████████████████| 563kB 33.4MB/s eta 0:00:01
[K     |████████████████████████████████| 573kB 33.4MB/s
    [?25hRequirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
    Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)
    Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)
    Collecting tokenizers==0.5.2
    [?25l  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
    [K     |████████████████████████████████| 3.7MB 59.5MB/s
    [?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.40)
    Collecting sentencepiece
    [?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
    [K     |████████████████████████████████| 1.0MB 54.3MB/s
    [?25hCollecting sacremoses
    [?25l  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
    [K     |████████████████████████████████| 890kB 52.9MB/s
    [?25hRequirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
    Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)
    Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)
    Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
    Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)
    Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
    Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)
    Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)
    Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.40)
    Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)
    Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)
    Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)
    Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (2.8.1)
    Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (0.15.2)
    Building wheels for collected packages: sacremoses
      Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
      Created wheel for sacremoses: filename=sacremoses-0.0.41-cp36-none-any.whl size=893334 sha256=5debc70bf2760c36e513997d6cfe94649592d95dd3b9654ec526cd2a32cad0f2
      Stored in directory: /root/.cache/pip/wheels/22/5a/d4/b020a81249de7dc63758a34222feaa668dbe8ebfe9170cc9b1
    Successfully built sacremoses
    Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers
    Successfully installed sacremoses-0.0.41 sentencepiece-0.1.85 tokenizers-0.5.2 transformers-2.8.0

%% Cell type:code id: tags:

``` 
!pip install simpletransformers
!pip install wandb
```

%% Output

    Collecting simpletransformers
    [?25l  Downloading https://files.pythonhosted.org/packages/ce/cc/4b42c1c362c7c3b939ebf5b628145abf69aeb8e1ac3f79770577466319c1/simpletransformers-0.25.0-py3-none-any.whl (157kB)
    
[K     |██                              | 10kB 26.0MB/s eta 0:00:01
[K     |████▏                           | 20kB 30.8MB/s eta 0:00:01
[K     |██████▎                         | 30kB 27.9MB/s eta 0:00:01
[K     |████████▎                       | 40kB 30.3MB/s eta 0:00:01
[K     |██████████▍                     | 51kB 19.1MB/s eta 0:00:01
[K     |████████████▌                   | 61kB 17.8MB/s eta 0:00:01
[K     |██████████████▌                 | 71kB 16.2MB/s eta 0:00:01
[K     |████████████████▋               | 81kB 16.0MB/s eta 0:00:01
[K     |██████████████████▊             | 92kB 15.8MB/s eta 0:00:01
[K     |████████████████████▊           | 102kB 15.8MB/s eta 0:00:01
[K     |██████████████████████▉         | 112kB 15.8MB/s eta 0:00:01
[K     |█████████████████████████       | 122kB 15.8MB/s eta 0:00:01
[K     |███████████████████████████     | 133kB 15.8MB/s eta 0:00:01
[K     |█████████████████████████████   | 143kB 15.8MB/s eta 0:00:01
[K     |███████████████████████████████▏| 153kB 15.8MB/s eta 0:00:01
[K     |████████████████████████████████| 163kB 15.8MB/s
    [?25hRequirement already satisfied: tokenizers in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (0.5.2)
    Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2.21.0)
    Collecting seqeval
      Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz
    Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (0.22.2.post1)
    Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.0.3)
    Collecting tensorboardx
    [?25l  Downloading https://files.pythonhosted.org/packages/35/f1/5843425495765c8c2dd0784a851a93ef204d314fc87bcc2bbb9f662a3ad1/tensorboardX-2.0-py2.py3-none-any.whl (195kB)
    [K     |████████████████████████████████| 204kB 26.6MB/s
    [?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (4.38.0)
    Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.4.1)
    Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2019.12.20)
    Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.18.2)
    Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2.8.0)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (2020.4.5.1)
    Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (3.0.4)
    Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (2.8)
    Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (1.24.3)
    Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->simpletransformers) (2.3.1)
    Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->simpletransformers) (0.14.1)
    Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->simpletransformers) (2018.9)
    Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->simpletransformers) (2.8.1)
    Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tensorboardx->simpletransformers) (1.12.0)
    Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardx->simpletransformers) (3.10.0)
    Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.7)
    Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (1.12.40)
    Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.1.85)
    Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.0.41)
    Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (3.0.12)
    Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (2.10.0)
    Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (3.13)
    Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (1.1.0)
    Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (1.0.8)
    Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorboardx->simpletransformers) (46.1.3)
    Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (1.15.40)
    Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (0.3.3)
    Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (0.9.5)
    Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->simpletransformers) (7.1.1)
    Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers->simpletransformers) (0.15.2)
    Building wheels for collected packages: seqeval
      Building wheel for seqeval (setup.py) ... [?25l[?25hdone
      Created wheel for seqeval: filename=seqeval-0.0.12-cp36-none-any.whl size=7424 sha256=9e0590e9e861f1347cbf412dcdb58e84cd910ed9897342f135f0caf8b102ccab
      Stored in directory: /root/.cache/pip/wheels/4f/32/0a/df3b340a82583566975377d65e724895b3fad101a3fb729f68
    Successfully built seqeval
    Installing collected packages: seqeval, tensorboardx, simpletransformers
    Successfully installed seqeval-0.0.12 simpletransformers-0.25.0 tensorboardx-2.0
    Collecting wandb
    [?25l  Downloading https://files.pythonhosted.org/packages/68/dd/ce719d36c4172b56c7579a79fcfd2f731c386b39f258bb186ef17b73fd7d/wandb-0.8.32-py2.py3-none-any.whl (1.4MB)
    [K     |████████████████████████████████| 1.4MB 22.0MB/s
    [?25hRequirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.21.0)
    Collecting gql==0.2.0
      Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz
    Collecting docker-pycreds>=0.4.0
      Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl
    Collecting sentry-sdk>=0.4.0
    [?25l  Downloading https://files.pythonhosted.org/packages/20/7e/19545324e83db4522b885808cd913c3b93ecc0c88b03e037b78c6a417fa8/sentry_sdk-0.14.3-py2.py3-none-any.whl (103kB)
    [K     |████████████████████████████████| 112kB 49.4MB/s
    [?25hCollecting GitPython>=1.0.0
    [?25l  Downloading https://files.pythonhosted.org/packages/19/1a/0df85d2bddbca33665d2148173d3281b290ac054b5f50163ea735740ac7b/GitPython-3.1.1-py3-none-any.whl (450kB)
    [K     |████████████████████████████████| 460kB 55.5MB/s
    [?25hCollecting shortuuid>=0.5.0
      Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl
    Collecting watchdog>=0.8.3
    [?25l  Downloading https://files.pythonhosted.org/packages/73/c3/ed6d992006837e011baca89476a4bbffb0a91602432f73bd4473816c76e2/watchdog-0.10.2.tar.gz (95kB)
    [K     |████████████████████████████████| 102kB 13.4MB/s
    [?25hRequirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.1)
    Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)
    Collecting subprocess32>=3.5.3
    [?25l  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)
    [K     |████████████████████████████████| 102kB 13.1MB/s
    [?25hCollecting configparser>=3.8.1
      Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl
    Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)
    Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)
    Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)
    Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2020.4.5.1)
    Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.8)
    Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (1.24.3)
    Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)
    Collecting graphql-core<2,>=0.5.0
    [?25l  Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)
    [K     |████████████████████████████████| 71kB 9.5MB/s
    [?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)
    Collecting gitdb<5,>=4.0.1
    [?25l  Downloading https://files.pythonhosted.org/packages/74/52/ca35448b56c53a079d3ffe18b1978c6e424f6d4df02404877094c89f5bfb/gitdb-4.0.4-py3-none-any.whl (63kB)
    [K     |████████████████████████████████| 71kB 11.1MB/s
    [?25hCollecting pathtools>=0.1.1
      Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz
    Collecting smmap<4,>=3.0.1
      Downloading https://files.pythonhosted.org/packages/27/b1/e379cfb7c07bbf8faee29c4a1a2469dbea525f047c2b454c4afdefa20a30/smmap-3.0.2-py2.py3-none-any.whl
    Building wheels for collected packages: gql, watchdog, subprocess32, graphql-core, pathtools
      Building wheel for gql (setup.py) ... [?25l[?25hdone
      Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=88d395f05da00e481a02baafbdb85210b0ac9c1ebd46282674b22ba931f49b49
      Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23
      Building wheel for watchdog (setup.py) ... [?25l[?25hdone
      Created wheel for watchdog: filename=watchdog-0.10.2-cp36-none-any.whl size=73605 sha256=fc4714af40db86d8a9cc75d2531cb62ef7118b522f522607c9175386229ed6da
      Stored in directory: /root/.cache/pip/wheels/bc/ed/6c/028dea90d31b359cd2a7c8b0da4db80e41d24a59614154072e
      Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
      Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=48d3d5ee6e337e6149a50fc26fba262fbee92d8780e397d39ed5ef26ee59e6b4
      Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1
      Building wheel for graphql-core (setup.py) ... [?25l[?25hdone
      Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=fad902d3f06c9fc44b5fe47e103a279b23b3c3b71b3c75b8b1ae66e430e3ac61
      Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5
      Building wheel for pathtools (setup.py) ... [?25l[?25hdone
      Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=75048560edcc7c400832dc7e006d620847ca8c44625c452ac9a650464f549bd1
      Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843
    Successfully built gql watchdog subprocess32 graphql-core pathtools
    Installing collected packages: graphql-core, gql, docker-pycreds, sentry-sdk, smmap, gitdb, GitPython, shortuuid, pathtools, watchdog, subprocess32, configparser, wandb
    Successfully installed GitPython-3.1.1 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.4 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.14.3 shortuuid-1.0.1 smmap-3.0.2 subprocess32-3.5.4 wandb-0.8.32 watchdog-0.10.2

%% Cell type:markdown id: tags:

From here, we want to load the dataset from tox21 for training the model. We're going to use a filtered dataset of 2100 compounds, as there are only 400 positive leads and we want to avoid having a large data imbalance. We'll also use simple-transformer's `auto_weights` argument in defining our ChemBERTa model to do automatic weight balancing later on, to counteract this problem.


%% Cell type:code id: tags:

``` 
!cd ..
dataset_path = "/content/zrC7F8DcRs?amp=1"
df = pd.read_csv(dataset_path, sep = ',', warn_bad_lines=True, header=None)


df.rename(columns={0:'smiles',1:'labels'}, inplace=True)
df.head()
```

%% Output

                                smiles  labels
    0       CCCCCCCC/C=C\CCCCCCCC(N)=O       0
    1             CCCCCCOC(=O)c1ccccc1       0
    2    O=C(c1ccc(Cl)cc1)c1ccc(Cl)cc1       0
    3              COc1cc(Cl)c(OC)cc1N       0
    4  N[C@H](Cc1c[nH]c2ccccc12)C(=O)O       0

%% Cell type:markdown id: tags:

From here, lets set up a logger to record if any issues occur, and notify us if there are any problems with the arguments we've set for the model.

%% Cell type:code id: tags:

``` 
from simpletransformers.classification import ClassificationModel
import logging

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
```

%% Cell type:markdown id: tags:

Now, using `simple-transformer`, let's load the pre-trained model from HuggingFace's useful model-hub. We'll set the number of epochs to 3 in the arguments, but you can train for longer. Also make sure that `auto_weights` is set to True as we are dealing with imbalanced toxicity datasets.

%% Cell type:code id: tags:

``` 
model = ClassificationModel('roberta', 'seyonec/ChemBERTa-zinc-base-v1', args={'num_train_epochs': 3, 'auto_weights': True}) # You can set class weights by using the optional weight argument
```

%% Cell type:code id: tags:

``` 
# Split the train and test dataset 80-20

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200).reset_index(drop=True)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
```

%% Cell type:code id: tags:

``` 
# check if our train and evaluation dataframes are setup properly. There should only be two columns for the SMILES string and its corresponding label.

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))
```

%% Output

    FULL Dataset: (2142, 2)
    TRAIN Dataset: (1714, 2)
    TEST Dataset: (428, 2)

%% Cell type:markdown id: tags:

Now that we've set everything up, lets get to the fun part: training the model! We use Weights and Biases, which is optional (simply remove `wandb_project` from the list of args). Its a really useful tool for monitering the model's training results (such as accuracy, learning rate and loss), alongside with custom visualizations you can create as well as the gradients.

When you run this cell, Weights and Biases will ask for an account, which you can setup when you get a key through a Github account. Again, this is completely optional and it can be removed from the list of arguments.

%% Cell type:code id: tags:

``` 
# Create directory to store model weights (change path accordingly to where you want!)
!cd /content
!mkdir chemberta_tox21

# Train the model
model.train_model(train_dataset, output_dir='/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name'})
```

%% Output

    mkdir: cannot create directory ‘chemberta_tox21’: File exists

    /usr/local/lib/python3.6/dist-packages/simpletransformers/classification/classification_model.py:243: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
      "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
    INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


    
    Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.
    
    Defaults for this optimization level are:
    enabled                : True
    opt_level              : O1
    cast_model_type        : None
    patch_torch_functions  : True
    keep_batchnorm_fp32    : None
    master_weights         : None
    loss_scale             : dynamic
    Processing user overrides (additional kwargs that are not None)...
    After processing overrides, optimization options are:
    enabled                : True
    opt_level              : O1
    cast_model_type        : None
    patch_torch_functions  : True
    keep_batchnorm_fp32    : None
    master_weights         : None
    loss_scale             : dynamic



    wandb: ERROR Not authenticated.  Copy a key from https://app.wandb.ai/authorize

    API Key: ··········

    wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


    INFO:wandb.run_manager:system metrics and metadata threads started
    INFO:wandb.run_manager:checking resume status, waiting at most 10 seconds
    INFO:wandb.run_manager:resuming run from id: UnVuOnYxOjEyMGVtb2htOnByb2plY3QtbmFtZTpzZXlvbmVj
    INFO:wandb.run_manager:upserting run before process can begin, waiting at most 10 seconds
    INFO:wandb.run_manager:saving patches
    INFO:wandb.run_manager:saving pip packages
    INFO:wandb.run_manager:initializing streaming files api
    INFO:wandb.run_manager:unblocking file change observer, beginning sync with W&B servers


    /usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:113: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
      "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/config.yaml

    Running loss: 0.788242

    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media/graph/graph_0_summary_0fce41b2.graph.json
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/requirements.txt
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media/graph
    INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media

    Running loss: 0.511236

    /usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:224: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
      warnings.warn("To get the last learning rate computed by the scheduler, "
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.602639

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.232389

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.853039

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl

    Running loss: 0.501909

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.066326


    Running loss: 0.170885

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.221705

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

    Running loss: 0.207991

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.173742

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.456498

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl

    Running loss: 1.234981

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

    Running loss: 0.397285

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.094101


    Running loss: 0.043023

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.053245

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

    Running loss: 0.175583

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.182486

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl

    Running loss: 0.071419

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.565325

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

    Running loss: 0.601075

    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json

    Running loss: 0.970592
    

    INFO:simpletransformers.classification.classification_model: Training of roberta model complete. Saved to /content/chemberta_tox21.
    INFO:wandb.run_manager:shutting down system stats and metadata service
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
    INFO:wandb.run_manager:stopping streaming files and file change observer
    INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

%% Cell type:markdown id: tags:

Let's install scikit-learn now, to evaluate the model we've trained.

%% Cell type:code id: tags:

``` 
!pip install -U scikit-learn
```

%% Output

    Requirement already up-to-date: scikit-learn in /usr/local/lib/python3.6/dist-packages (0.22.2.post1)
    Requirement already satisfied, skipping upgrade: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.18.2)
    Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (0.14.1)
    Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.4.1)

%% Cell type:markdown id: tags:

The following cell can be ignored unless you are starting a new run-time and just want to load the model from your local directory.

%% Cell type:code id: tags:

``` 
# Loading a saved model for evaluation
model = ClassificationModel('roberta', '/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name','num_train_epochs': 3})
```

%% Cell type:code id: tags:

``` 
import sklearn
result, model_outputs, wrong_predictions = model.eval_model(test_dataset, acc=sklearn.metrics.accuracy_score)
```

%% Output

    /usr/local/lib/python3.6/dist-packages/simpletransformers/classification/classification_model.py:660: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
      "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
    INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


    


    INFO:simpletransformers.classification.classification_model:{'mcc': 0.7136017700095658, 'tp': 55, 'tn': 335, 'fp': 4, 'fn': 34, 'acc': 0.9112149532710281, 'eval_loss': 0.2323251810890657}

    

%% Cell type:markdown id: tags:

The model performs pretty well, averaging above 91% after training on only ~2000 data samples and 400 positive leads! We can clearly see the predictive power of transfer learning, and approaches like these are becoming increasing popular in the pharmaceutical industry where larger datasets are scarce. By training on more epochs and tasks, we can probably boost the accuracy as well!

Lets train the model on one last string outside of the filtered dataset for toxicity. The model should predict 0, meaning no interference in biochemical pathways for p53.

%% Cell type:code id: tags:

``` 
# Lets input a molecule with a SR-p53 value of 0
predictions, raw_outputs = model.predict(['CCCCOc1cc(C(=O)OCCN(CC)CC)ccc1N'])
```

%% Output

    INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


    


    

%% Cell type:code id: tags:

``` 
print(predictions)
print(raw_outputs)
```

%% Output

    [0]
    [[ 2.9023438 -2.859375 ]]

%% Cell type:markdown id: tags:

The model predicts the sample correctly! Some future tasks may include using the same model on multiple tasks (Tox21 provides multiple for toxicity), through multi-task classification, as well as training on a wider dataset. This will be expanded on in a future tutorial!

%% Cell type:markdown id: tags:

#Congratulations! Time to join the Community!
Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

# **Star DeepChem on [Github](https://github.com/deepchem/deepchem)**
This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

# **Join the DeepChem Gitter**
The DeepChem [Gitter](https://gitter.im/deepchem/Lobby) hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!