Commit 3e03ad9a authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix bug for data type

parent 97f5d2df
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -839,17 +839,17 @@ class TorchModel(Model):
    inputs = [
        x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs
    ]
    inputs = [torch.as_tensor(x, device=self.device) for x in inputs]
    inputs = [torch.as_tensor(x, device=self.device).float() for x in inputs]
    if labels is not None:
      labels = [
          x.astype(np.float32) if x.dtype == np.float64 else x for x in labels
      ]
      labels = [torch.as_tensor(x, device=self.device) for x in labels]
      labels = [torch.as_tensor(x, device=self.device).float() for x in labels]
    if weights is not None:
      weights = [
          x.astype(np.float32) if x.dtype == np.float64 else x for x in weights
      ]
      weights = [torch.as_tensor(x, device=self.device) for x in weights]
      weights = [torch.as_tensor(x, device=self.device).float() for x in weights]

    return (inputs, labels, weights)