Commit 75d6a018 authored by miaecle's avatar miaecle
Browse files

fix bug in lr

parent fc37366f
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -211,9 +211,8 @@ class TensorflowLogisticRegression(TensorflowGraphModel):
        # transfer 2D prediction tensor to 2D x n_classes(=2) 
        complimentary = np.ones(np.shape(batch_output))
        complimentary = complimentary - batch_output
        batch_output = np.squeeze(np.stack(arrays = [complimentary,
                                                     batch_output],
                                            axis = 2))
        batch_output = np.concatenate([complimentary, batch_output],
                                            axis = batch_output.ndim-1)
        # reshape to batch_size x n_tasks x ...
        if batch_output.ndim == 3:
          batch_output = batch_output.transpose((1, 0, 2))