Commit 1bfb1cc3 authored by mufeili's avatar mufeili
Browse files

Update

parent 05ae3a79
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -364,7 +364,7 @@ class TorchModel(Model):

      # Execute the loss function, accumulating the gradients.

      if len(inputs) == 1:
      if isinstance(inputs, list) and len(inputs) == 1:
        inputs = inputs[0]

      optimizer.zero_grad()
@@ -524,7 +524,7 @@ class TorchModel(Model):
      inputs, _, _ = self._prepare_batch((inputs, None, None))

      # Invoke the model.
      if len(inputs) == 1:
      if isinstance(inputs, list) and len(inputs) == 1:
        inputs = inputs[0]
      output_values = self.model(inputs)
      if isinstance(output_values, torch.Tensor):