Commit 6b938e17 authored by nd-02110114's avatar nd-02110114
Browse files

revert

parent 2d3b9000
Loading
Loading
Loading
Loading
+2 −4
Original line number Diff line number Diff line
@@ -838,7 +838,7 @@ class TorchModel(Model):
    inputs = [
        x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs
    ]
    inputs = [torch.as_tensor(x, device=self.device).float() for x in inputs]
    inputs = [torch.as_tensor(x, device=self.device) for x in inputs]
    if labels is not None:
      labels = [
          x.astype(np.float32) if x.dtype == np.float64 else x for x in labels
@@ -848,9 +848,7 @@ class TorchModel(Model):
      weights = [
          x.astype(np.float32) if x.dtype == np.float64 else x for x in weights
      ]
      weights = [
          torch.as_tensor(x, device=self.device).float() for x in weights
      ]
      weights = [torch.as_tensor(x, device=self.device) for x in weights]

    return (inputs, labels, weights)