Unverified Commit 8c0f9b52 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2146 from deepchem/docstorch

Adding docs for PyTorch Models
parents b1cac7b1 b71a8aa6
Loading
Loading
Loading
Loading
+64 −4
Original line number Diff line number Diff line
Model Classes
=============

DeepChem maintains an extensive collection of models for scientific applications.
DeepChem maintains an extensive collection of models for scientific
applications. DeepChem's focus is on facilitating scientific applications, so
we support a broad range of different machine learning frameworks (currently
scikit-learn, xgboost, TensorFlow, and PyTorch) since different frameworks are
more and less suited for different scientific applications.

Model Cheatsheet
----------------
@@ -125,6 +129,12 @@ read off what's needed to train the model from the table below.
+----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+
| :code:`WGAN`                           | Adversarial| Pair                 |                        |                                                                | :code:`fit_gan`      |
+----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+
| :code:`CGCNNModel`                     | Classifier/| :code:`GraphData`    |                        | :code:`CGCNNFeaturizer`                                        | :code:`fit`          |
|                                        | Regressor  |                      |                        |                                                                |                      |
+----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+
| :code:`GATModel`                       | Classifier/| :code:`GraphData`    |                        | :code:`MolGraphConvFeaturizer`                                 | :code:`fit`          |
|                                        | Regressor  |                      |                        |                                                                |                      |
+----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+

Model
-----
@@ -132,12 +142,24 @@ Model
.. autoclass:: deepchem.models.Model
  :members:

Scikit-Learn Models
===================

Scikit-learn's models can be wrapped so that they can interact conveniently
with DeepChem. Oftentimes scikit-learn models are more robust and easier to
train and are a nice first model to train.

SklearnModel
------------

.. autoclass:: deepchem.models.SklearnModel
  :members:

Xgboost Models
==============

Xgboost models can be wrapped so they can interact with DeepChem.

XGBoostModel
------------

@@ -145,9 +167,13 @@ XGBoostModel
  :members:


Keras Models
============
DeepChem extensively uses `Keras`_ to build powerful machine learning models.
Deep Learning Infrastructure
============================

DeepChem maintains a lightweight layer of common deep learning model
infrastructure that can be used for models built with different underlying
frameworks. The losses and optimizers can be used for both TensorFlow and
PyTorch models.

Losses
------
@@ -216,6 +242,12 @@ Optimizers
  :members:


Keras Models
============

DeepChem extensively uses `Keras`_ to build deep learning models.


KerasModel
----------

@@ -373,3 +405,31 @@ ChemCeption

.. autoclass:: deepchem.models.ChemCeption
  :members:

PyTorch Models
==============

DeepChem supports the use of `PyTorch`_ to build deep learning models.

.. _`PyTorch`: https://pytorch.org/ 

TorchModel
----------

You can wrap an arbitrary :code:`torch.nn.Module` in a :code:`TorchModel` object.

.. autoclass:: deepchem.models.TorchModel
  :members:

CGCNNModel
----------

.. autoclass:: deepchem.models.CGCNNModel
  :members:


GATModel
--------

.. autoclass:: deepchem.models.GATModel
  :members:
+2 −0
Original line number Diff line number Diff line
@@ -3,3 +3,5 @@ scikit-learn
sphinx_rtd_theme
tensorflow==2.2.0
transformers
xgboost
torch==1.6.0