Commit ef070999 authored by nd-02110114, committed by Bharath Ramsundar
Browse files

♻️ refactor

parent cfcc7a4a
Loading
Loading
Loading
Loading
+16 −56
Original line number Diff line number Diff line
@@ -145,9 +145,6 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    string representations of hyperparameter sets to validation
    scores.
    """
    if len(params_dict) > 20:
      raise ValueError(
          "This class can only search over 20 parameters in one invocation.")
    # Specify logfile
    if logfile:
      log_file = logfile
@@ -156,48 +153,16 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    else:
      log_file = None

    # Set up the search range for each parameter
    param_range = compute_parameter_range(params_dict, search_range)
    param_range_keys = list(param_range.keys())
    param_range_values = [param_range[key] for key in param_range_keys]

    # Number of parameters
    n_param = 0
    for val in params_dict.items():
      if isinstance(val, list):
        n_param += len(val)
      else:
        n_param += 1

    # Dummy names
    param_name = ['l' + format(i, '02d') for i in range(20)]
    # This is the dictionary of arguments we'll pass to pyGPGO
    param = dict(zip(param_name[:n_param], param_range_values))
    param_keys = list(param_range.keys())

    # Stores all results
    all_results = {}

    # Demarcating internal function for readability
    ########################
    def f(l00=0,
          l01=0,
          l02=0,
          l03=0,
          l04=0,
          l05=0,
          l06=0,
          l07=0,
          l08=0,
          l09=0,
          l10=0,
          l11=0,
          l12=0,
          l13=0,
          l14=0,
          l15=0,
          l16=0,
          l17=0,
          l18=0,
          l19=0):
    def f(**placeholders):
      """Private Optimizing function

      Take in hyper parameter values and return valid set performances
@@ -214,17 +179,13 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
        valid set performances
      """
      hyper_parameters = {}
      # This is a dictionary of form {'l01': val, ...} binding
      # arguments
      args = locals()
      # This bit of code re-associates hyperparameter values to their
      # names from the arguments of this local function.
      for i, hp in enumerate(param_range_keys):
        if isinstance(params_dict[hp], int):
          hyper_parameters[hp] = int(args[param_name[i]])
        elif isinstance(params_dict[hp], float):
          hyper_parameters[hp] = float(args[param_name[i]])

      for hp in param_keys:
        if param_range[hp][0] == "int":
          # Bayesian optimization (BO) always passes parameter values as
          # floats, so integer-typed parameters are cast back to int here.
          # See: https://github.com/josejimenezluna/pyGPGO/issues/10
          hyper_parameters[hp] = int(placeholders[hp])
        else:
          hyper_parameters[hp] = float(placeholders[hp])
      logger.info("Running hyperparameter set: %s" % str(hyper_parameters))
      if log_file:
        # Run benchmark
@@ -283,18 +244,17 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    cov = matern32()
    gp = GaussianProcess(cov)
    acq = Acquisition(mode='ExpectedImprovement')
    gpgo = GPGO(gp, acq, f, param)
    gpgo = GPGO(gp, acq, f, param_range)
    logger.info("Max number of iteration: %i" % max_iter)
    gpgo.run(max_iter=max_iter)

    hp_opt, valid_performance_opt = gpgo.getResult()
    # Readout best hyper parameters
    hyper_parameters = {}
    for i, hp in enumerate(param_range_keys):
      if isinstance(params_dict[hp], int):
        hyper_parameters[hp] = int(hp_opt[param_name[i]])
      elif isinstance(params_dict[hp], float):
        hyper_parameters[hp] = float(hp_opt[param_name[i]])
    for hp in param_keys:
      if param_range[hp][0] == "int":
        hyper_parameters[hp] = int(hp_opt[hp])
      else:
        hyper_parameters[hp] = float(hp_opt[hp])
    hp_str = _convert_hyperparam_dict_to_filename(hyper_parameters)
    model_dir = "model%s" % hp_str
    hyper_parameters["model_dir"] = model_dir