Commit 0d274fee authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Added ability to save keras models

parent 241c90a3
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -50,7 +50,6 @@ def fit_singletask_mlp(train_data, task_types, **training_params):
    print "Target %s" % target
    (train_ids, X_train, y_train, W_train) = train_data[target]
    print "%d compounds in Train" % len(train_ids)
    print "%d compounds in Test" % len(test)
    models[target] = train_multitask_model(X_train, y_train, W_train,
        {target: task_types[target]}, **training_params)
  return models
+1 −1
Original line number Diff line number Diff line
@@ -145,7 +145,7 @@ def parse_args(input_args=None):
  group.add_argument("--paths", nargs="+", required=1,
                      help="Paths to input datasets.")
  group.add_argument("--modeltype", required=1,
                      choices=["sklearn", "keras"],
                      choices=["sklearn", "keras-graph", "keras-sequential"],
                      help="Type of model to load.")
  # TODO(rbharath): This argument seems a bit extraneous. Is it really
  # necessary?
+3 −2
Original line number Diff line number Diff line
@@ -74,7 +74,7 @@ def model_predictions(X, model, n_targets, task_types, modeltype="sklearn"):
      raise ValueError("Tensorial datatype must be of shape (n_samples, N, N, N, n_channels).")
    (n_samples, axis_length, _, _, n_channels) = np.shape(X)
    X = np.reshape(X, (n_samples, axis_length, n_channels, axis_length, axis_length))
  if modeltype == "keras_multitask":
  if modeltype == "keras-graph":
    predictions = model.predict({"input": X})
    ypreds = []
    for index in range(n_targets):
@@ -86,10 +86,11 @@ def model_predictions(X, model, n_targets, task_types, modeltype="sklearn"):
      ypreds = model.predict_proba(X)
    elif task_type == "regression":
      ypreds = model.predict(X)
  elif modeltype == "keras":
  elif modeltype == "keras-sequential":
    ypreds = model.predict(X)
  else:
    raise ValueError("Improper modeltype.")
  ypreds = np.squeeze(ypreds)
  ypreds = np.reshape(ypreds, (len(ypreds), n_targets))
  return ypreds

+52 −10
Original line number Diff line number Diff line
"""
Utility functions to save models.
Utility functions to save keras/sklearn models.
"""
import os
import gzip
import cPickle as pickle
from keras.models import model_from_json
from sklearn.externals import joblib

# TODO(rbharath): This implementation only supports saving single models. Make
# some way to save metadata in addition to the actual model file.
def save_model(model, modeltype, filename):
def save_model(models, modeltype, filename):
  """Dispatcher function for saving."""
  if modeltype == "sklearn":
    save_sklearn_model(model, filename)
  elif modeltype == "keras":
    save_keras_model(model, filename)
    save_sklearn_model(models, filename)
  elif "keras" in modeltype:
    save_keras_model(models, filename)
  else:
    raise ValueError("Unsupported modeltype.")

@@ -18,16 +20,56 @@ def load_model(modeltype, filename):
  """Dispatcher function for loading."""
  if modeltype == "sklearn":
    return load_sklearn_model(filename)
  elif modeltype == "keras":
  elif "keras" in modeltype:
    return load_keras_model(filename)
  else:
    raise ValueError("Unsupported modeltype.")

def save_sklearn_model(model, filename):
def save_sklearn_model(models, filename):
  """Saves sklearn model to disk using joblib."""
  joblib.dump(model, filename)
  joblib.dump(models, filename)

def load_sklearn_model(filename):
  """Reads a sklearn model back from disk via joblib.

  Args:
    filename: Path handed straight to joblib.load.

  Returns:
    The object that joblib.load reconstructs from the file.
  """
  restored = joblib.load(filename)
  return restored
  
def save_keras_model(models, filename):
  """Writes a dictionary of keras models to disk.

  With <base> denoting filename stripped of its extension, the following
  files are produced:
    <base>.pkl.gz          gzipped pickle of the target names (models.keys())
    <base>-<target>.json   architecture of the model for each target
    <base>-<target>.h5     weights of the model for each target

  Args:
    models: Dictionary mapping target name to a keras model.
    filename: Output path; its extension is discarded.
  """
  base, _ = os.path.splitext(filename)

  # The pkl.gz file will store the target names so that the per-target files
  # can be located again later.
  targets_path = base + ".pkl.gz"
  with gzip.open(targets_path, "wb") as targets_file:
    pickle.dump(models.keys(), targets_file)

  # Keras requires the architecture and the weights to be stored separately:
  # the architecture goes in a json file, the weights go in an h5 file.
  for target, model in models.items():
    prefix = "%s-%s" % (base, target)
    with open(prefix + ".json", "wb") as arch_file:
      arch_file.write(model.to_json())
    model.save_weights(prefix + ".h5")

def load_keras_model(filename):
  """Loads keras models from disk.

  Reads the file layout written by save_keras_model. With <base> denoting
  filename stripped of its extension, this expects:
    <base>.pkl.gz          gzipped pickle holding the target names
    <base>-<target>.json   model architecture for each target
    <base>-<target>.h5     model weights for each target

  Args:
    filename: Path of the saved models; its extension is discarded.

  Returns:
    Dictionary mapping each target name to its reconstructed keras model.
  """
  filename, _ = os.path.splitext(filename)
  pkl_gz_filename = "%s.%s" % (filename, "pkl.gz")
  # NOTE: unpickling can execute arbitrary code; only load files from a
  # trusted source.
  with gzip.open(pkl_gz_filename) as f:
    targets = pickle.load(f)
  models = {}
  for target in targets:
    # The per-target suffix must match the names used when saving; without it
    # every target would read the same (nonexistent) <base>.json/<base>.h5.
    json_filename = "%s-%s.%s" % (filename, target, "json")
    h5_filename = "%s-%s.%s" % (filename, target, "h5")

    with open(json_filename) as f:
      model = model_from_json(f.read())
    model.load_weights(h5_filename)
    models[target] = model
  return models