Commit a2ba234d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fixed bugs in undo_transform_outputs

parent 00eeac13
Loading
Loading
Loading
Loading
+1 −6
Original line number Diff line number Diff line
@@ -49,15 +49,10 @@ def undo_normalization(y_orig, y_pred):
  """Undo the applied normalization transform."""
  old_mean = np.mean(y_orig)
  old_std = np.std(y_orig)
  return y_orig * old_std + old_mean
  return y_pred * old_std + old_mean

def undo_transform_outputs(y_raw, y_pred, output_transforms):
  """Undo transforms on y_pred, W_pred."""
  print "undo_transform_outputs()"
  print "output_transforms"
  print output_transforms
  print "y_raw"
  print y_raw
  if output_transforms == ["log"]:
    return np.exp(y_pred)
  elif output_transforms == ["normalize"]:
+0 −0

Empty file added.

+26 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ __license__ = "LGPL"
import numpy as np
import unittest
from deep_chem.utils.preprocess import balance_positives
from deep_chem.utils.preprocess import undo_transform_outputs

def ensure_balanced(y, W):
  """Helper function that ensures postives and negatives are balanced."""
@@ -39,3 +40,28 @@ class TestPreprocess(unittest.TestCase):
          pos_weight += Wbal[sample_ind, target_ind]
      assert np.isclose(pos_weight, neg_weight)

  def test_undo_transform_outputs(self):
    # Test undo-log
    y_raw = np.ones(10)
    y_pred = np.log(y_raw)
    output_transforms = ["log"]
    assert np.array_equal(y_raw, undo_transform_outputs(y_raw, y_pred, output_transforms))

    # Test undo-normalization
    y_raw = np.random.randint(0, 10, size=(10,))
    mean = np.mean(y_raw)
    std = np.std(y_raw)
    y_pred = (y_raw-mean)/std
    output_transforms = ["normalize"]
    y_ret = undo_transform_outputs(y_raw, y_pred, output_transforms)
    assert np.allclose(y_raw, y_ret)
    
    # Test undo log-normalization
    y_raw = np.random.randint(1, 10, size=(10,))
    y_pred = np.log(y_raw)
    mean = np.mean(y_pred)
    std = np.std(y_pred)
    y_pred = (y_pred - mean)/std
    output_transforms = ["log", "normalize"]
    y_ret = undo_transform_outputs(y_raw, y_pred, output_transforms)
    assert np.allclose(y_raw, y_ret)