Commit d1dd37e6 authored by peastman's avatar peastman
Browse files

Merge branch 'master' into conv

parents 1c4e9ba6 3467db1e
Loading
Loading
Loading
Loading
+65 −38
Original line number Diff line number Diff line
@@ -947,6 +947,9 @@ class DiskDataset(Dataset):

    logger.info("Loading dataset from disk.")
    self.tasks, self.metadata_df = self.load_metadata()
    self._cached_shards = None
    self._memory_cache_size = 20 * (1 << 20)  # 20 MB
    self._cache_used = 0

  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[]):
@@ -1021,25 +1024,25 @@ class DiskDataset(Dataset):
                         w=None,
                         ids=None):
    if X is not None:
      out_X = "%s-X.joblib" % basename
      out_X = "%s-X.npy" % basename
      save_to_disk(X, os.path.join(data_dir, out_X))
    else:
      out_X = None

    if y is not None:
      out_y = "%s-y.joblib" % basename
      out_y = "%s-y.npy" % basename
      save_to_disk(y, os.path.join(data_dir, out_y))
    else:
      out_y = None

    if w is not None:
      out_w = "%s-w.joblib" % basename
      out_w = "%s-w.npy" % basename
      save_to_disk(w, os.path.join(data_dir, out_w))
    else:
      out_w = None

    if ids is not None:
      out_ids = "%s-ids.joblib" % basename
      out_ids = "%s-ids.npy" % basename
      save_to_disk(ids, os.path.join(data_dir, out_ids))
    else:
      out_ids = None
@@ -1050,6 +1053,7 @@ class DiskDataset(Dataset):
  def save_to_disk(self):
    """Save dataset to disk.

    Passes this dataset's task list, metadata table and data directory to
    `save_metadata`, then discards the in-memory shard cache.
    """
    save_metadata(self.tasks, self.metadata_df, self.data_dir)
    # Drop the shard cache once the metadata has been written.
    self._cached_shards = None

  def move(self, new_data_dir):
    """Moves dataset to new directory."""
@@ -1143,32 +1147,7 @@ class DiskDataset(Dataset):
    generator defined by this function returns the data from a particular shard.
    The order of shards returned is guaranteed to remain fixed.
    """

    def iterate(dataset):
      for _, row in dataset.metadata_df.iterrows():
        X = np.array(load_from_disk(os.path.join(dataset.data_dir, row['X'])))
        ids = np.array(
            load_from_disk(os.path.join(dataset.data_dir, row['ids'])),
            dtype=object)
        # These columns may be missing if the dataset is unlabelled.
        if row['y'] is not None:
          y = np.array(load_from_disk(os.path.join(dataset.data_dir, row['y'])))
        else:
          y = None
        if row['w'] is not None:
          w_filename = os.path.join(dataset.data_dir, row['w'])
          if os.path.exists(w_filename):
            w = np.array(load_from_disk(w_filename))
          else:
            if len(y.shape) == 1:
              w = np.ones(y.shape[0], np.float32)
            else:
              w = np.ones((y.shape[0], 1), np.float32)
        else:
          w = None
        yield (X, y, w, ids)

    return iterate(self)
    return (self.get_shard(i) for i in range(self.get_number_shards()))

  def iterbatches(self,
                  batch_size=None,
@@ -1607,6 +1586,24 @@ class DiskDataset(Dataset):

  def get_shard(self, i):
    """Retrieves data for the i-th shard from disk."""

    class Shard(object):

      def __init__(self, X, y, w, ids):
        self.X = X
        self.y = y
        self.w = w
        self.ids = ids

    # See if we have a cached copy of this shard.
    if self._cached_shards is None:
      self._cached_shards = [None] * self.get_number_shards()
      self._cache_used = 0
    if self._cached_shards[i] is not None:
      shard = self._cached_shards[i]
      return (shard.X, shard.y, shard.w, shard.ids)

    # We don't, so load it from disk.
    row = self.metadata_df.iloc[i]
    X = np.array(load_from_disk(os.path.join(self.data_dir, row['X'])))

@@ -1630,7 +1627,24 @@ class DiskDataset(Dataset):

    ids = np.array(
        load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)
    return (X, y, w, ids)

    # Try to cache this shard for later use.  Since the normal usage pattern is
    # a series of passes through the whole dataset, there's no point doing
    # anything fancy.  It never makes sense to evict another shard from the
    # cache to make room for this one, because we'll probably want that other
    # shard again before the next time we want this one.  So just cache as many
    # as we can and then stop.

    shard = Shard(X, y, w, ids)
    shard_size = X.nbytes + ids.nbytes
    if y is not None:
      shard_size += y.nbytes
    if w is not None:
      shard_size += w.nbytes
    if self._cache_used + shard_size < self._memory_cache_size:
      self._cached_shards[i] = shard
      self._cache_used += shard_size
    return (shard.X, shard.y, shard.w, shard.ids)

  def add_shard(self, X, y, w, ids):
    """Adds a data shard."""
@@ -1649,6 +1663,7 @@ class DiskDataset(Dataset):
    basename = "shard-%d" % shard_num
    tasks = self.get_task_names()
    DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
    self._cached_shards = None

  def select(self, indices, select_dir=None):
    """Creates a new dataset from a selection of indices from self.
@@ -1759,6 +1774,18 @@ class DiskDataset(Dataset):
    else:
      return np.concatenate(ws)

  @property
  def memory_cache_size(self):
    """Get the size of the memory cache for this dataset, measured in bytes.

    Returns
    -------
    The current value of `self._memory_cache_size`.
    """
    return self._memory_cache_size

  @memory_cache_size.setter
  def memory_cache_size(self, size):
    """Set the size of the memory cache for this dataset, measured in bytes.

    If the number of bytes currently accounted to cached shards exceeds the
    new size, all cached shards are dropped.

    Parameters
    ----------
    size
      New cache budget; it is compared against `self._cache_used`.
    """
    self._memory_cache_size = size
    if self._cache_used > size:
      # Over the new budget: discard the whole cache rather than evicting
      # individual shards.
      self._cached_shards = None

  def __len__(self):
    """
    Finds number of elements in dataset.
+2 −2
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from sklearn.metrics import mean_absolute_error
from sklearn.metrics import precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn.metrics import jaccard_similarity_score
from sklearn.metrics import jaccard_score
from sklearn.metrics import f1_score
from scipy.stats import pearsonr

@@ -91,7 +91,7 @@ def jaccard_index(y, y_pred):
      y: ground truth array
      y_pred: predicted array
    """
  return jaccard_similarity_score(y, y_pred)
  return jaccard_score(y, y_pred)


def pixel_error(y, y_pred):
+2 −0
Original line number Diff line number Diff line
@@ -153,6 +153,8 @@ class Splitter(object):
    else:
      valid_dataset = None
    test_dataset = dataset.select(test_inds, test_dir)
    if isinstance(train_dataset, DiskDataset):
      train_dataset.memory_cache_size = 40 * (1 << 20)  # 40 MB

    return train_dataset, valid_dataset, test_dataset

+4 −5
Original line number Diff line number Diff line
@@ -62,7 +62,6 @@ class Transformer(object):
               transform_w=False,
               dataset=None):
    """Initializes transformation based on dataset statistics."""
    self.dataset = dataset
    self.transform_X = transform_X
    self.transform_y = transform_y
    self.transform_w = transform_w
@@ -482,12 +481,12 @@ class BalancingTransformer(Transformer):
    assert transform_w

    # Compute weighting factors from dataset.
    y = self.dataset.y
    w = self.dataset.w
    y = dataset.y
    w = dataset.w
    # Ensure dataset is binary
    np.testing.assert_allclose(sorted(np.unique(y)), np.array([0., 1.]))
    weights = []
    for ind, task in enumerate(self.dataset.get_task_names()):
    for ind, task in enumerate(dataset.get_task_names()):
      task_w = w[:, ind]
      task_y = y[:, ind]
      # Remove labels with zero weights
@@ -505,7 +504,7 @@ class BalancingTransformer(Transformer):
  def transform_array(self, X, y, w):
    """Transform the data in a set of (X, y, w) arrays."""
    w_balanced = np.zeros_like(w)
    for ind, task in enumerate(self.dataset.get_task_names()):
    for ind in range(y.shape[1]):
      task_y = y[:, ind]
      task_w = w[:, ind]
      zero_indices = np.logical_and(task_y == 0, task_w != 0)
+352 −0
Original line number Diff line number Diff line
"""This module adds utilities for coordinate boxes"""
import numpy as np
from scipy.spatial import ConvexHull


def intersect_interval(interval1, interval2):
  """Return the overlap of two closed intervals.

  Parameters
  ----------
  interval1: tuple[int]
    The first interval, given as `(x1_min, x1_max)`.
  interval2: tuple[int]
    The second interval, given as `(x2_min, x2_max)`.

  Returns
  -------
  x_intersect: tuple[int]
    `(max(x1_min, x2_min), min(x1_max, x2_max))` when the intervals
    overlap or share an endpoint. When one interval lies strictly
    below the other, `(0, 0)` is returned as the stand-in for the
    empty set.
  """
  lo_a, hi_a = interval1
  lo_b, hi_b = interval2
  # The intervals are disjoint exactly when one ends before the other starts.
  disjoint = hi_a < lo_b or hi_b < lo_a
  if disjoint:
    return (0, 0)
  return (max(lo_a, lo_b), min(hi_a, hi_b))


def intersection(box1, box2):
  """Build the box formed by intersecting two boxes axis by axis.

  Parameters
  ----------
  box1: `CoordinateBox`
    The first box.
  box2: `CoordinateBox`
    The box to intersect with `box1`.

  Returns
  -------
  A `CoordinateBox` whose x, y and z ranges are the results of
  `intersect_interval` on the corresponding ranges of the inputs. An
  axis on which the boxes do not overlap gets the range `(0, 0)`.
  """
  axis_names = ("x_range", "y_range", "z_range")
  overlaps = [
      intersect_interval(getattr(box1, axis), getattr(box2, axis))
      for axis in axis_names
  ]
  return CoordinateBox(*overlaps)


def union(box1, box2):
  """Return the smallest box that encloses both input boxes.

  Parameters
  ----------
  box1: `CoordinateBox`
    The first box to enclose.
  box2: `CoordinateBox`
    The second box to enclose.

  Returns
  -------
  Smallest `CoordinateBox` that contains both `box1` and `box2`. On
  each axis its range runs from the smaller of the two minima to the
  larger of the two maxima.
  """
  enclosing = []
  for axis in ("x_range", "y_range", "z_range"):
    first = getattr(box1, axis)
    second = getattr(box2, axis)
    enclosing.append((min(first[0], second[0]), max(first[1], second[1])))
  return CoordinateBox(*enclosing)


def merge_overlapping_boxes(boxes, threshold=.8):
  """Greedily merge boxes whose overlap reaches a volume threshold.

  Each box is compared against every box in `boxes` in order. Whenever
  the volume of their intersection is at least `threshold` times the
  volume of either of the two, the current box is replaced by the
  union of the pair, and the scan continues with the grown box. The
  result for a box is kept only if no previously kept box already
  contains it.

  Parameters
  ----------
  boxes: list[CoordinateBox]
    A list of `CoordinateBox` objects.
  threshold: float, optional (default 0.8)
    The volume fraction of the boxes that must overlap for them to be
    merged together.

  Returns
  -------
  list[CoordinateBox] of merged boxes. This list will have length less
  than or equal to the length of `boxes`.
  """
  kept = []
  for candidate in boxes:
    for other in boxes:
      # Skip boxes equal to the (possibly already grown) candidate.
      if candidate == other:
        continue
      overlap_volume = intersection(candidate, other).volume()
      enough_of_candidate = overlap_volume >= threshold * candidate.volume()
      if enough_of_candidate or overlap_volume >= threshold * other.volume():
        candidate = union(candidate, other)
    # Only keep the result if no earlier output already covers it.
    if not any(existing.contains(candidate) for existing in kept):
      kept.append(candidate)
  return kept


def get_face_boxes(coords, pad=5):
  """Build a padded bounding box around every face of the convex hull.

  The convex hull of a macromolecule is made of triangular faces. For
  each face, this function takes the three atoms at its corners and
  constructs the axis-aligned box that bounds them, enlarged by `pad`
  on every side. Such a box is a crude, purely geometric guess at a
  region where a binding interaction might occur.

  On each axis the smallest corner coordinate is rounded down and the
  largest is rounded up to an integer before `pad` is subtracted or
  added.

  Parameters
  ----------
  coords: np.ndarray
    Of shape `(N, 3)`. The coordinates of a molecule.
  pad: float, optional (default 5)
    The number of angstroms to pad.

  Returns
  -------
  list[CoordinateBox] with one box per simplex of the convex hull, in
  the order given by `hull.simplices`.

  Examples
  --------
  >>> coords = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
  >>> boxes = get_face_boxes(coords, pad=5)
  """
  hull = ConvexHull(coords)
  face_boxes = []
  # Every simplex lists the indices of the three atoms that form one
  # exterior triangle of the hull.
  for simplex in hull.simplices:
    corners = np.array([coords[simplex[k]] for k in (0, 1, 2)])
    axis_bounds = []
    for axis in (0, 1, 2):
      values = corners[:, axis]
      # Snap outwards to integers first, then apply the padding.
      lower = int(np.floor(np.amin(values))) - pad
      upper = int(np.ceil(np.amax(values))) + pad
      axis_bounds.append((lower, upper))
    face_boxes.append(CoordinateBox(*axis_bounds))
  return face_boxes


class CoordinateBox(object):
  """An axis-aligned box that represents a block in space.

  Molecular complexes are typically represented with atoms as
  coordinate points, and a complex is naturally associated with
  several box-shaped regions: the bounding box that contains every
  atom, a binding pocket box focused on the region where a protein
  binds a ligand, or an interface box covering the region in which
  two proteins interact.

  A `CoordinateBox` describes such a region by its extent along each
  axis, stored as the `(min, max)` tuples `x_range`, `y_range` and
  `z_range`.
  """

  def __init__(self, x_range, y_range, z_range):
    """Validate and store the three axis ranges.

    Parameters
    ----------
    x_range: tuple
      A tuple of `(x_min, x_max)` with min and max x-coordinates.
    y_range: tuple
      A tuple of `(y_min, y_max)` with min and max y-coordinates.
    z_range: tuple
      A tuple of `(z_min, z_max)` with min and max z-coordinates.

    Raises
    ------
    `ValueError` if a range is not a tuple of length 2, or if its
    minimum is not `<=` its maximum. Ranges are checked in x, y, z
    order and the first failure is reported.
    """
    checks = (
        (x_range, "x_range must be a tuple of length 2",
         "x minimum must be <= x maximum"),
        (y_range, "y_range must be a tuple of length 2",
         "y minimum must be <= y maximum"),
        (z_range, "z_range must be a tuple of length 2",
         "z minimum must be <= z maximum"),
    )
    for bounds, shape_error, order_error in checks:
      if not (isinstance(bounds, tuple) and len(bounds) == 2):
        raise ValueError(shape_error)
      lower, upper = bounds
      if not lower <= upper:
        raise ValueError(order_error)
    self.x_range = x_range
    self.y_range = y_range
    self.z_range = z_range

  def __repr__(self):
    """Return a string showing the bounds of this box on each axis."""
    return "Box[x_bounds=%s, y_bounds=%s, z_bounds=%s]" % (
        self.x_range, self.y_range, self.z_range)

  def __str__(self):
    """Return the same string as `__repr__`."""
    return repr(self)

  def __contains__(self, point):
    """Check whether a point lies in this box, boundaries included.

    Parameters
    ----------
    point: 3-tuple or list of length 3 or  np.ndarray of shape `(3,)`
      The `(x, y, z)` coordinates of a point in space.

    Returns
    -------
    Truthy if every coordinate of `point` lies within the matching
    range of this box.
    """
    x_lo, x_hi = self.x_range
    y_lo, y_hi = self.y_range
    z_lo, z_hi = self.z_range
    inside_x = x_lo <= point[0] <= x_hi
    inside_y = y_lo <= point[1] <= y_hi
    inside_z = z_lo <= point[2] <= z_hi
    return inside_x and inside_y and inside_z

  def __eq__(self, other):
    """Compare two boxes for equality of all their bounds.

    Parameters
    ----------
    other: `CoordinateBox`
      The box to compare this one with.

    Returns
    -------
    bool that's `True` if the x, y and z ranges all match.

    Raises
    ------
    `ValueError` if attempting to compare to something that isn't a
    `CoordinateBox`.
    """
    if not isinstance(other, CoordinateBox):
      raise ValueError("Can only compare to another box.")
    return all(
        getattr(self, axis) == getattr(other, axis)
        for axis in ("x_range", "y_range", "z_range"))

  def __hash__(self):
    """Hash this box by its bounds.

    Uses the default `hash` on the tuple `(self.x_range, self.y_range,
    self.z_range)`, so boxes that compare equal hash equally.

    Returns
    -------
    Integer hash value.
    """
    bounds = (self.x_range, self.y_range, self.z_range)
    return hash(bounds)

  def center(self):
    """Computes the center of this box.

    Returns
    -------
    `(x, y, z)` the coordinates of the midpoint of each range.

    Examples
    --------
    >>> box = CoordinateBox((0, 1), (0, 1), (0, 1))
    >>> box.center()
    (0.5, 0.5, 0.5)
    """
    ranges = (self.x_range, self.y_range, self.z_range)
    return tuple(lower + (upper - lower) / 2 for lower, upper in ranges)

  def volume(self):
    """Computes and returns the volume of this box.

    Returns
    -------
    The product of the box's extents along x, y and z. It is 0 when
    the box is degenerate on any axis.

    Examples
    --------
    >>> box = CoordinateBox((0, 1), (0, 1), (0, 1))
    >>> box.volume()
    1
    """
    ranges = (self.x_range, self.y_range, self.z_range)
    width, height, depth = (upper - lower for lower, upper in ranges)
    return width * height * depth

  def contains(self, other):
    """Test whether another box lies entirely inside this one.

    Boxes that share a face still count as contained, since the
    comparison on every bound is inclusive.

    Parameters
    ----------
    other: `CoordinateBox`
      The box to check is contained in this box.

    Returns
    -------
    bool, `True` if `other` is contained in this box.

    Raises
    ------
    `ValueError` if `not isinstance(other, CoordinateBox)`.
    """
    if not isinstance(other, CoordinateBox):
      raise ValueError("other must be a CoordinateBox")
    outer_x_lo, outer_x_hi = self.x_range
    outer_y_lo, outer_y_hi = self.y_range
    outer_z_lo, outer_z_hi = self.z_range
    inner_x_lo, inner_x_hi = other.x_range
    inner_y_lo, inner_y_hi = other.y_range
    inner_z_lo, inner_z_hi = other.z_range
    fits_x = outer_x_lo <= inner_x_lo and inner_x_hi <= outer_x_hi
    fits_y = outer_y_lo <= inner_y_lo and inner_y_hi <= outer_y_hi
    fits_z = outer_z_lo <= inner_z_lo and inner_z_hi <= outer_z_hi
    return fits_x and fits_y and fits_z
Loading