Unverified Commit 34aa0b4a authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1894 from TaranSinghania/master

Added Data Augmentation Support
parents daa2a72c bda1d7c6
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
@@ -587,6 +587,35 @@ class TestTransformers(unittest.TestCase):
    check_blur = scipy.ndimage.gaussian_filter(self.d, 1.5)
    assert np.allclose(check_blur, blurred)

  def test_center_crop(self):
    # Check center crop
    dt = DataTransforms(self.d)
    x_crop = 50
    y_crop = 50
    crop = dt.center_crop(x_crop, y_crop)
    y = self.d.shape[0]
    x = self.d.shape[1]
    x_start = x // 2 - (x_crop // 2)
    y_start = y // 2 - (y_crop // 2)
    check_crop = self.d[y_start:y_start + y_crop, x_start:x_start + x_crop]
    assert np.allclose(check_crop, crop)

  def test_crop(self):
    #Check crop
    dt = DataTransforms(self.d)
    crop = dt.crop(0, 10, 0, 10)
    y = self.d.shape[0]
    x = self.d.shape[1]
    check_crop = self.d[10:y - 10, 0:x - 0]
    assert np.allclose(crop, check_crop)

  def test_convert2gray(self):
    # Check convert2gray
    dt = DataTransforms(self.d)
    gray = dt.convert2gray()
    check_gray = np.dot(self.d[..., :3], [0.2989, 0.5870, 0.1140])
    assert np.allclose(check_gray, gray)

  def test_rotation(self):
    # Check rotation
    dt = DataTransforms(self.d)
@@ -677,3 +706,13 @@ class TestTransformers(unittest.TestCase):
    # atoms. These are denoted the "parents"
    for idm, mol in enumerate(dataset.X):
      assert dataset.X[idm].get_num_atoms() == len(dataset.X[idm].parents)

  def test_median_filter(self):
    #Check median filter
    from PIL import Image, ImageFilter
    dt = DataTransforms(self.d)
    filtered = dt.median_filter(size=3)
    image = Image.fromarray(self.d)
    image = image.filter(ImageFilter.MedianFilter(size=3))
    check_filtered = np.array(image)
    assert np.allclose(check_filtered, filtered)
+79 −2
Original line number Diff line number Diff line
@@ -1306,8 +1306,15 @@ class DataTransforms(Transformer):

  def rotate(self, angle=0):
    """ Rotates the image
          Parameters:
              angle (default = 0 i.e no rotation) - Denotes angle by which the image should be rotated (in Degrees)

    Parameters
    ----------
    angle: float (default = 0 i.e no rotation)
	Denotes angle by which the image should be rotated (in Degrees)

    Returns
    ----------
    The rotated imput array
    """
    return scipy.ndimage.rotate(self.Image, angle)

@@ -1318,6 +1325,59 @@ class DataTransforms(Transformer):
    """
    return scipy.ndimage.gaussian_filter(self.Image, sigma)

  def center_crop(self, x_crop, y_crop):
    """ Crops the image from the center

    Parameters
    ----------
    x_crop: int
	the total number of pixels to remove in the horizontal direction, evenly split between the left and right sides
    y_crop: int
        the total number of pixels to remove in the vertical direction, evenly split between the top and bottom sides

    Returns
    ----------
    The center cropped input array

    """
    y = self.Image.shape[0]
    x = self.Image.shape[1]
    x_start = x // 2 - (x_crop // 2)
    y_start = y // 2 - (y_crop // 2)
    return self.Image[y_start:y_start + y_crop, x_start:x_start + x_crop]

  def crop(self, left, top, right, bottom):
    """ Crops the image and returns the specified rectangular region from an image

    Parameters
    ----------
    left: int
	the number of pixels to exclude from the left of the image
    top: int
	the number of pixels to exclude from the top of the image
    right: int
	the number of pixels to exclude from the right of the image    
    bottom: int
	the number of pixels to exclude from the bottom of the image

    Returns
    ----------
    The cropped input array
    """
    y = self.Image.shape[0]
    x = self.Image.shape[1]
    return self.Image[top:y - bottom, left:x - right]

  def convert2gray(self):
    """ Converts the image to grayscale. The coefficients correspond to the Y' component of the Y'UV color system.
    
    Returns
    ----------
    The grayscale image.

    """
    return np.dot(self.Image[..., :3], [0.2989, 0.5870, 0.1140])

  def shift(self, width, height, mode='constant', order=3):
    """Shifts the image
        Parameters:
@@ -1358,3 +1418,20 @@ class DataTransforms(Transformer):
    x[noise < (prob / 2)] = pepper
    x[noise > (1 - prob / 2)] = salt
    return x

  def median_filter(self, size):
    """ Calculates a multidimensional median filter

    Parameters
    ----------
    size: int
	The kernel size in pixels.

    Returns
    ----------
    The median filtered image.
    """
    from PIL import Image, ImageFilter
    image = Image.fromarray(self.Image)
    image = image.filter(ImageFilter.MedianFilter(size=size))
    return np.array(image)