Commit 8f30a8a5 authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix inconsistent API for HyperparamOpt

parent cd7d2c1b
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -233,7 +233,7 @@ class MolecularFeaturizer(Featurizer):
  The subclasses of this class require RDKit to be installed.
  """

  def featurize(self, molecules, log_every_n=1000):
  def featurize(self, molecules, log_every_n=1000) -> np.ndarray:
    """Calculate features for molecules.

    Parameters
@@ -315,7 +315,7 @@ class MaterialStructureFeaturizer(Featurizer):
  """

  def featurize(self,
                structures: Iterable[Union[Dict[str, Any], PymatgenStructure]],
                structures: Iterable[Union[Dict, PymatgenStructure]],
                log_every_n: int = 1000) -> np.ndarray:
    """Calculate features for crystal structures.

+16 −10
Original line number Diff line number Diff line
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

from deepchem.data import Dataset
from deepchem.trans import Transformer
from deepchem.models import Model
from deepchem.metrics import Metric

@@ -73,15 +74,15 @@ class HyperparamOpt(object):
          You probably want to instantiate a concrete subclass instead.")
    self.model_builder = model_builder

  def hyperparam_search(
      self,
      params_dict: Dict[str, Any],
  def hyperparam_search(self,
                        params_dict: Dict,
                        train_dataset: Dataset,
                        valid_dataset: Dataset,
                        output_transformers: List[Transformer],
                        metric: Metric,
                        use_max: bool = True,
                        logdir: Optional[str] = None,
      **kwargs) -> Tuple[Model, Dict[str, Any], Dict[str, float]]:
                        **kwargs) -> Tuple[Model, Dict, Dict]:
    """Conduct Hyperparameter search.

    This method defines the common API shared by all hyperparameter
@@ -102,6 +103,11 @@ class HyperparamOpt(object):
      dataset used for training
    valid_dataset: Dataset
      dataset used for validation(optimization on valid scores)
    output_transformers: list[Transformer]
      Transformers for evaluation. This argument is needed since
      `train_dataset` and `valid_dataset` may have been transformed
      for learning and need the transform to be inverted before
      the metric can be evaluated on a model.
    metric: Metric
      metric used for evaluation
    use_max: bool, optional
+24 −17
Original line number Diff line number Diff line
@@ -7,16 +7,17 @@ import tempfile
from typing import Dict, List, Optional, Tuple, Union

from deepchem.data import Dataset
from deepchem.trans import Transformer
from deepchem.metrics import Metric
from deepchem.utils.evaluate import Evaluator
from deepchem.hyper.base_classes import HyperparamOpt
from deepchem.hyper.base_classes import _convert_hyperparam_dict_to_filename

logger = logging.getLogger(__name__)
PARAM_DICT = Dict[str, Union[int, float]]


def compute_parameter_range(params_dict: PARAM_DICT,
                            search_range: Union[int, float, PARAM_DICT]
def compute_parameter_range(params_dict: Dict,
                            search_range: Union[int, float, Dict]
                           ) -> Dict[str, Tuple[str, List[float]]]:
  """Convenience Function to compute parameter search space.

@@ -126,19 +127,18 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
  This class requires pyGPGO to be installed.
  """

  # NOTE: mypy prohibits changing the number of arguments
  # FIXME: Signature of "hyperparam_search" incompatible with supertype "HyperparamOpt"
  def hyperparam_search(  # type: ignore[override]
      self,
      params_dict: PARAM_DICT,
  def hyperparam_search(self,
                        params_dict: Dict,
                        train_dataset: Dataset,
                        valid_dataset: Dataset,
                        output_transformers: List[Transformer],
                        metric: Metric,
                        use_max: bool = True,
                        logdir: Optional[str] = None,
                        max_iter: int = 20,
      search_range: Union[int, float, PARAM_DICT] = 4,
      logfile: Optional[str] = None):
                        search_range: Union[int, float, Dict] = 4,
                        logfile: Optional[str] = None,
                        **kwargs):
    """Perform hyperparameter search using a gaussian process.

    Parameters
@@ -154,6 +154,11 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
      dataset used for training
    valid_dataset: Dataset
      dataset used for validation(optimization on valid scores)
    output_transformers: list[Transformer]
      Transformers for evaluation. This argument is needed since
      `train_dataset` and `valid_dataset` may have been transformed
      for learning and need the transform to be inverted before
      the metric can be evaluated on a model.
    metric: Metric
      metric used for evaluation
    use_max: bool, (default True)
@@ -280,7 +285,9 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
      except NotImplementedError:
        pass

      multitask_scores = model.evaluate(valid_dataset, [metric])
      # multitask_scores = model.evaluate(valid_dataset, [metric])
      evaluator = Evaluator(model, valid_dataset, output_transformers)
      multitask_scores = evaluator.compute_model_performance([metric])
      score = multitask_scores[metric.name]

      if log_file:
+5 −6
Original line number Diff line number Diff line
@@ -60,17 +60,16 @@ class GridHyperparamOpt(HyperparamOpt):

  """

  # NOTE: mypy prohibits changing the number of arguments
  # FIXME: Signature of "hyperparam_search" incompatible with supertype "HyperparamOpt"
  def hyperparam_search(  # type: ignore[override]
  def hyperparam_search(
      self,
      params_dict: Dict[str, List],
      params_dict: Dict,
      train_dataset: Dataset,
      valid_dataset: Dataset,
      output_transformers: List[Transformer],
      metric: Metric,
      use_max: bool = True,
      logdir: Optional[str] = None,
      **kwargs,
  ):
    """Perform hyperparams search according to params_dict.

@@ -156,7 +155,7 @@ class GridHyperparamOpt(HyperparamOpt):
      evaluator = Evaluator(model, valid_dataset, output_transformers)
      multitask_scores = evaluator.compute_model_performance([metric])
      # NOTE: this casting is workaround. This line doesn't effect anything to the runtime
      multitask_scores = cast(Dict[str, float], multitask_scores)
      multitask_scores = cast(Dict, multitask_scores)
      valid_score = multitask_scores[metric.name]
      hp_str = _convert_hyperparam_dict_to_filename(hyper_params)
      all_scores[hp_str] = valid_score
@@ -183,7 +182,7 @@ class GridHyperparamOpt(HyperparamOpt):
    train_evaluator = Evaluator(best_model, train_dataset, output_transformers)
    multitask_scores = train_evaluator.compute_model_performance([metric])
    # NOTE: this casting is workaround. This line doesn't effect anything to the runtime
    multitask_scores = cast(Dict[str, float], multitask_scores)
    multitask_scores = cast(Dict, multitask_scores)
    train_score = multitask_scores[metric.name]
    logger.info("Best hyperparameters: %s" % str(best_hyperparams))
    logger.info("train_score: %f" % train_score)
+14 −3
Original line number Diff line number Diff line
@@ -42,7 +42,12 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
    metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)

    best_model, best_hyperparams, all_results = optimizer.hyperparam_search(
        params_dict, self.train_dataset, self.valid_dataset, metric, max_iter=2)
        params_dict,
        self.train_dataset,
        self.valid_dataset,
        transformers,
        metric,
        max_iter=2)

    valid_score = best_model.evaluate(self.valid_dataset, [metric],
                                      transformers)
@@ -61,6 +66,7 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
        params_dict,
        self.train_dataset,
        self.valid_dataset,
        transformers,
        metric,
        use_max=False,
        max_iter=2)
@@ -81,6 +87,7 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
          params_dict,
          self.train_dataset,
          self.valid_dataset,
          transformers,
          metric,
          logdir=tmpdirname,
          max_iter=2)
@@ -99,6 +106,7 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
        np.arange(10))
    valid_dataset = dc.data.NumpyDataset(
        np.random.rand(5, 3), np.zeros((5, 2)), np.ones((5, 2)), np.arange(5))
    transformers = []

    optimizer = dc.hyper.GaussianProcessHyperparamOpt(
        lambda **params: dc.models.MultitaskRegressor(n_tasks=2,
@@ -114,11 +122,12 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
        params_dict,
        train_dataset,
        valid_dataset,
        transformers,
        metric,
        max_iter=1,
        use_max=False)

    valid_score = best_model.evaluate(valid_dataset, [metric])
    valid_score = best_model.evaluate(valid_dataset, [metric], transformers)
    assert valid_score["mean-mean_squared_error"] == min(all_results.values())
    assert valid_score["mean-mean_squared_error"] > 0

@@ -132,6 +141,7 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
        np.arange(10))
    valid_dataset = dc.data.NumpyDataset(
        np.random.rand(5, 3), np.zeros((5, 2)), np.ones((5, 2)), np.arange(5))
    transformers = []

    optimizer = dc.hyper.GaussianProcessHyperparamOpt(
        lambda **params: dc.models.MultitaskRegressor(
@@ -152,12 +162,13 @@ class TestGaussianHyperparamOpt(unittest.TestCase):
          params_dict,
          train_dataset,
          valid_dataset,
          transformers,
          metric,
          max_iter=2,
          logdir=tmpdirname,
          search_range=search_range,
          use_max=False)
      valid_score = best_model.evaluate(valid_dataset, [metric])
      valid_score = best_model.evaluate(valid_dataset, [metric], transformers)
    # Test that 2 parameters were optimized
    for hp_str in all_results.keys():
      # Recall that the key is a string of the form _batch_size_39_learning_rate_0.01 for example
Loading