Commit 0339c2d4 authored by evanfeinberg's avatar evanfeinberg
Browse files

added additional sharding to deal with GPU mem limit

parent d4519b34
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -78,7 +78,8 @@ class DockingDNN(Model):
  def fit_on_batch(self, X, y, w):
    # TODO(rbharath): Modify the featurization so that it matches desired shaped.
    X = shuffle_data(X)
    self.raw_model.train_on_batch(X, y)
    loss = self.raw_model.train_on_batch(X, y)
    print("Loss: %f" % loss)

  def predict_on_batch(self, X):
    if len(np.shape(X)) != 5:
+15 −1
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ def compute_y_pred(model, data_dir, csv_out, split):

  split_df = metadata_df.loc[metadata_df['split'] == split]
  nb_batch = split_df.shape[0]
  MAX_GPU_RAM = float(691007488/50)

  for i, row in split_df.iterrows():
    print("Evaluating on %s batch %d out of %d" % (split, i+1, nb_batch))
@@ -55,7 +56,20 @@ def compute_y_pred(model, data_dir, csv_out, split):
    w = load_sharded_dataset(row['w'])
    ids = load_sharded_dataset(row['ids'])

    if sys.getsizeof(X) > MAX_GPU_RAM:
      nb_block = float(sys.getsizeof(X))/MAX_GPU_RAM
      nb_sample = np.shape(X)[0]
      interval_points = np.linspace(0,nb_sample,nb_block+1).astype(int)
      for j in range(0,len(interval_points)-1):
        indices = range(interval_points[j],interval_points[j+1])
        X_batch = X[indices,:]
        y_batch = y[indices]
        w_batch = w[indices]
        y_preds.append(model.predict_on_batch(X_batch))
      y_pred = np.concatenate(y_preds)
    else:
      y_pred = model.predict_on_batch(X)

    y_pred = np.reshape(y_pred, np.shape(y))

    mini_df = pd.DataFrame(columns=column_names)
+16 −3
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from deep_chem.utils.save import load_sharded_dataset
from deep_chem.utils.save import save_model
from deep_chem.utils.preprocess import get_task_type
import numpy as np
import sys

def get_task_names(metadata_df):
  """
@@ -38,14 +39,26 @@ def fit_model(model_name, model_params, model_dir, data_dir):

  train_metadata = metadata_df.loc[metadata_df['split'] =="train"]
  nb_batch = train_metadata.shape[0]

  MAX_GPU_RAM = float(691007488/50)
  for i, row in train_metadata.iterrows():
    print("Training on batch %d out of %d" % (i+1, nb_batch))

    X = load_sharded_dataset(row['X'])
    y = load_sharded_dataset(row['y'])
    w = load_sharded_dataset(row['w'])

    print("sys.getsizeof(X): %s" % str(sys.getsizeof(X)))
    if sys.getsizeof(X) > MAX_GPU_RAM:
      print("X exceeds available GPU memory size. Sharding.")
      nb_block = float(sys.getsizeof(X))/MAX_GPU_RAM
      nb_sample = np.shape(X)[0]
      interval_points = np.linspace(0,nb_sample,nb_block+1).astype(int)
      for j in range(0,len(interval_points)-1):
        indices = range(interval_points[j],interval_points[j+1])
        X_batch = X[indices,:]
        y_batch = y[indices]
        w_batch = w[indices]
        model.fit_on_batch(X_batch, y_batch, w_batch)
    else:
      model.fit_on_batch(X, y, w)

  save_model(model, model_name, model_dir)