Commit abf79c1c authored by casey's avatar casey
Browse files

Added nb_epochs to grid_search

parent 0f7332a9
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -66,6 +66,7 @@ class GridHyperparamOpt(HyperparamOpt):
      valid_dataset: Dataset,
      metric: Metric,
      output_transformers: List[Transformer] = [],
      nb_epoch: int = 5,
      use_max: bool = True,
      logdir: Optional[str] = None,
      **kwargs,
@@ -91,6 +92,7 @@ class GridHyperparamOpt(HyperparamOpt):
      `train_dataset` and `valid_dataset` may have been transformed
      for learning and need the transform to be inverted before
      the metric can be evaluated on a model.
    nb_epoch: int, (default 5)
    use_max: bool, optional
      If True, return the model with the highest score. Else return
      model with the minimum score.
@@ -144,7 +146,7 @@ class GridHyperparamOpt(HyperparamOpt):
        model_dir = tempfile.mkdtemp()
      model_params['model_dir'] = model_dir
      model = self.model_builder(**model_params)
      model.fit(train_dataset)
      model.fit(train_dataset, nb_epoch=nb_epoch)
      try:
        model.save()
      # Some models autosave