Commit 7365535e authored by peastman's avatar peastman
Browse files

More type annotations

parent d8a17537
Loading
Loading
Loading
Loading
+191 −150

File changed.

Preview size limit exceeded, changes collapsed.

+15 −18
Original line number Diff line number Diff line
@@ -18,11 +18,10 @@ from deepchem.models.optimizers import Adam, Optimizer, LearningRateSchedule
from deepchem.trans import Transformer, undo_transforms
from deepchem.utils.evaluate import GeneratorEvaluator

from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.utils.typing import OneOrMany

LossFunction = Callable[[List, List, List], float]
T = TypeVar("T")
OneOrMany = Union[T, Sequence[T]]


class KerasModel(Model):
@@ -112,7 +111,7 @@ class KerasModel(Model):
               optimizer: Optional[Optimizer] = None,
               tensorboard: bool = False,
               log_frequency: int = 100,
               **kwargs):
               **kwargs) -> None:
    """Create a new KerasModel.

    Parameters
@@ -246,8 +245,7 @@ class KerasModel(Model):
          restore: bool = False,
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[LossFunction] = None,
          callbacks: Union[Callable, List[Callable]] = [],
          **kwargs) -> float:
          callbacks: Union[Callable, List[Callable]] = []) -> float:
    """Train this model on a dataset.

    Parameters
@@ -645,11 +643,11 @@ class KerasModel(Model):
    """
    return self._predict(generator, transformers, outputs, False, output_types)

  def predict_on_batch(self,
  def predict_on_batch(
      self,
      X: Sequence,
      transformers: List[Transformer] = [],
                       outputs: Optional[OneOrMany[tf.Tensor]] = None,
                       **kwargs) -> OneOrMany[np.ndarray]:
      outputs: Optional[OneOrMany[tf.Tensor]] = None) -> OneOrMany[np.ndarray]:
    """Generates predictions for input samples, processing samples in a batch.

    Parameters
@@ -813,12 +811,11 @@ class KerasModel(Model):
    else:
      return list(zip(output, std))

  def evaluate_generator(
      self,
  def evaluate_generator(self,
                         generator: Iterable[Tuple[Any, Any, Any]],
                         metrics: List[Metric],
                         transformers: List[Transformer] = [],
      per_task_metrics: bool = False) -> Dict[str, np.ndarray]:
                         per_task_metrics: bool = False):
    """Evaluate the performance of this model on the data produced by a generator.

    Parameters
@@ -1149,7 +1146,7 @@ class KerasModel(Model):
class _StandardLoss(object):
  """The implements the loss function for models that use a dc.models.losses.Loss."""

  def __init__(self, model: tf.keras.Model, loss: Loss):
  def __init__(self, model: tf.keras.Model, loss: Loss) -> None:
    self.model = model
    self.loss = loss

+64 −20
Original line number Diff line number Diff line
@@ -16,12 +16,16 @@ import sklearn
from sklearn.base import BaseEstimator

from deepchem.data import Dataset, pad_features
from deepchem.trans import undo_transforms
from deepchem.metrics import Metric
from deepchem.trans import Transformer, undo_transforms
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import log
from deepchem.utils.evaluate import Evaluator

from typing import Any, Dict, List, Optional, Sequence
from deepchem.utils.typing import OneOrMany


class Model(BaseEstimator):
  """
@@ -29,10 +33,10 @@ class Model(BaseEstimator):
  """

  def __init__(self,
               model_instance=None,
               model_dir=None,
               verbose=True,
               **kwargs):
               model_instance: Optional[Any] = None,
               model_dir: Optional[str] = None,
               verbose: bool = True,
               **kwargs) -> None:
    """Abstract class for all models.
    Parameters:
    -----------
@@ -58,14 +62,26 @@ class Model(BaseEstimator):
    if 'model_dir_is_temp' in dir(self) and self.model_dir_is_temp:
      shutil.rmtree(self.model_dir)

  def fit_on_batch(self, X, y, w):
    """
    Updates existing model with new information.
  def fit_on_batch(self, X: Sequence, y: Sequence, w: Sequence) -> float:
    """Perform a single step of training.

    Parameters
    ----------
    X: ndarray
      the inputs for the batch
    y: ndarray
      the labels for the batch
    w: ndarray
      the weights for the batch

    Returns
    -------
    the loss on the batch
    """
    raise NotImplementedError(
        "Each model is responsible for its own fit_on_batch method.")

  def predict_on_batch(self, X, **kwargs):
  def predict_on_batch(self, X: Sequence):
    """
    Makes predictions on given batch of new data.

@@ -77,7 +93,7 @@ class Model(BaseEstimator):
    raise NotImplementedError(
        "Each model is responsible for its own predict_on_batch method.")

  def reload(self):
  def reload(self) -> None:
    """
    Reload trained model from disk.
    """
@@ -85,29 +101,40 @@ class Model(BaseEstimator):
        "Each model is responsible for its own reload method.")

  @staticmethod
  def get_model_filename(model_dir):
  def get_model_filename(model_dir: str) -> str:
    """
    Given model directory, obtain filename for the model itself.
    """
    return os.path.join(model_dir, "model.joblib")

  @staticmethod
  def get_params_filename(model_dir):
  def get_params_filename(model_dir: str) -> str:
    """
    Given model directory, obtain filename for the model parameters.
    """
    return os.path.join(model_dir, "model_params.joblib")

  def save(self):
  def save(self) -> None:
    """Dispatcher function for saving.

    Each subclass is responsible for overriding this method.
    """
    raise NotImplementedError

  def fit(self, dataset, nb_epoch=10, **kwargs):
  def fit(self, dataset: Dataset, nb_epoch: int = 10) -> float:
    """
    Fits a model on data in a Dataset object.

    Parameters
    ----------
    dataset: Dataset
      the Dataset to train on
    nb_epoch: int
      the number of epochs to train for

    Returns
    -------
    the average loss over the most recent epoch
    """
    # TODO(rbharath/enf): We need a structured way to deal with potential GPU
    #                     memory overflows.
@@ -118,13 +145,26 @@ class Model(BaseEstimator):
        losses.append(self.fit_on_batch(X_batch, y_batch, w_batch))
      log("Avg loss for epoch %d: %f" % (epoch + 1, np.array(losses).mean()),
          self.verbose)
    return np.array(losses).mean()

  def predict(self, dataset, transformers=[]):
  def predict(self, dataset: Dataset,
              transformers: List[Transformer] = []) -> OneOrMany[np.ndarray]:
    """
    Uses self to make predictions on provided Dataset object.

    Returns:
      y_pred: numpy ndarray of shape (n_samples,)

    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    transformers: list of dc.trans.Transformers
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.

    Returns
    -------
    a NumPy array if the model produces a single output, or a list of arrays
    if it produces multiple outputs
    """
    y_preds = []
    n_tasks = self.get_num_tasks()
@@ -140,7 +180,11 @@ class Model(BaseEstimator):
    y_pred = np.concatenate(y_preds)
    return y_pred

  def evaluate(self, dataset, metrics, transformers=[], per_task_metrics=False):
  def evaluate(self,
               dataset: Dataset,
               metrics: List[Metric],
               transformers: List[Transformer] = [],
               per_task_metrics: bool = False):
    """
    Evaluates the performance of this model on specified dataset.

@@ -169,13 +213,13 @@ class Model(BaseEstimator):
          metrics, per_task_metrics=per_task_metrics)
      return scores, per_task_scores

  def get_task_type(self):
  def get_task_type(self) -> str:
    """
    Currently models can only be classifiers or regressors.
    """
    raise NotImplementedError

  def get_num_tasks(self):
  def get_num_tasks(self) -> int:
    """
    Get number of tasks.
    """
+7 −0
Original line number Diff line number Diff line
"""Type annotations that are widely used in DeepChem"""

from typing import Sequence, Tuple, TypeVar, Union

# Unconstrained type variable used to parameterize the generic aliases below.
T = TypeVar("T")
# Generic alias: either a single value of type T or a Sequence of T values,
# e.g. OneOrMany[int] accepts both `3` and `[1, 2, 3]`.
OneOrMany = Union[T, Sequence[T]]
# A tuple of ints of arbitrary length.
Shape = Tuple[int, ...]