Commit 8ab94f29 authored by casey's avatar casey
Browse files

Add nb_epoch to PyGPGO

parent 64993a12
Loading
Loading
Loading
Loading
+24 −19
Original line number Diff line number Diff line
@@ -132,6 +132,7 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
                        valid_dataset: Dataset,
                        metric: Metric,
                        output_transformers: List[Transformer] = [],
                        nb_epoch: int = 5,
                        use_max: bool = True,
                        logdir: Optional[str] = None,
                        max_iter: int = 20,
@@ -230,22 +231,8 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    # Stores all model locations
    model_locations = {}

    # Demarcating internal function for readability
    def optimizing_function(**placeholders):
      """Private Optimizing function

      Take in hyper parameter values and return valid set performances

      Parameters
      ----------
      placeholders: keyword arguments
        Should be various hyperparameters as specified in `param_keys` above.

      Returns:
      --------
      valid_scores: float
        valid set performances
      """
    # Private opt_func to pass nb_epoch for optimizing_function
    def _optimize(nb_epoch, **placeholders):
      hyper_parameters = {}
      for hp in param_keys:
        if param_range[hp][0] == "int":
@@ -277,7 +264,7 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
      # Add it on to the information needed for the constructor
      hyper_parameters["model_dir"] = model_dir
      model = self.model_builder(**hyper_parameters)
      model.fit(train_dataset)
      model.fit(train_dataset, nb_epoch=nb_epoch)
      try:
        model.save()
      # Some models autosave
@@ -305,6 +292,24 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
      else:
        return -score

    # Demarcating internal function for readability
    def optimizing_function(**placeholders):
      """Private Optimizing function

      Take in hyper parameter values and return valid set performances

      Parameters
      ----------
      placeholders: keyword arguments
        Should be various hyperparameters as specified in `param_keys` above.

      Returns:
      --------
      valid_scores: float
        valid set performances
      """
      return _optimize(nb_epoch=nb_epoch, **placeholders)
    
    # execute GPGO
    cov = matern32()
    gp = GaussianProcess(cov)