Commit abe2d888 authored by Arun's avatar Arun
Browse files

fixes attribute error from sklearn in print model

parent 22a8fbd7
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -175,6 +175,9 @@ class KerasModel(Model):
      the Weights & Biases logger object used to log data and metrics
    """
    super(KerasModel, self).__init__(model=model, model_dir=model_dir, **kwargs)
    self.loss = loss  # not used
    self.learning_rate = learning_rate  # not used
    self.output_types = output_types  # not used
    if isinstance(loss, Loss):
      self._loss_fn: LossFn = _StandardLoss(model, loss)
    else:
+3 −0
Original line number Diff line number Diff line
@@ -175,6 +175,9 @@ class TorchModel(Model):
      the Weights & Biases logger object used to log data and metrics
    """
    super(TorchModel, self).__init__(model=model, model_dir=model_dir, **kwargs)
    self.loss = loss  # not used
    self.learning_rate = learning_rate  # not used
    self.output_types = output_types  # not used
    if isinstance(loss, Loss):
      self._loss_fn: LossFn = _StandardLoss(self, loss)
    else: