Unverified Commit 9664aeab authored by Nathan Frey's avatar Nathan Frey Committed by GitHub
Browse files

Merge pull request #2450 from ncfrey/atomicconv-save

Add save reload to atomicconv
parents 5061825a e3d2c752
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ except:
  from collections import Sequence as SequenceCollection
from typing import Sequence, Union
from deepchem.utils.typing import KerasActivationFn, LossFn, OneOrMany
from deepchem.utils.data_utils import load_from_disk, save_to_disk

logger = logging.getLogger(__name__)

@@ -298,3 +299,11 @@ class AtomicConvModel(KerasModel):
        ]
        y_b = np.reshape(y_b, newshape=(batch_size, 1))
        yield (inputs, [y_b], [w_b])

  def save(self):
    """Saves model to disk using joblib."""
    save_to_disk(self.model, self.get_model_filename(self.model_dir))

  def reload(self):
    """Loads model from joblib file on disk."""
    self.model = load_from_disk(self.get_model_filename(self.model_dir))
+67 −0
Original line number Diff line number Diff line
@@ -282,6 +282,73 @@ def test_robust_multitask_classification_reload():
  assert scores[classification_metric.name] > .9


def test_atomic_conv_model_reload():
  from deepchem.models.atomic_conv import AtomicConvModel
  from deepchem.data import NumpyDataset
  model_dir = tempfile.mkdtemp()
  batch_size = 1
  N_atoms = 5

  acm = AtomicConvModel(
      n_tasks=1,
      batch_size=batch_size,
      layer_sizes=[
          1,
      ],
      frag1_num_atoms=5,
      frag2_num_atoms=5,
      complex_num_atoms=10,
      model_dir=model_dir)

  features = []
  frag1_coords = np.random.rand(N_atoms, 3)
  frag1_nbr_list = {0: [], 1: [], 2: [], 3: [], 4: []}
  frag1_z = np.random.randint(10, size=(N_atoms))
  frag2_coords = np.random.rand(N_atoms, 3)
  frag2_nbr_list = {0: [], 1: [], 2: [], 3: [], 4: []}
  frag2_z = np.random.randint(10, size=(N_atoms))
  system_coords = np.random.rand(2 * N_atoms, 3)
  system_nbr_list = {
      0: [],
      1: [],
      2: [],
      3: [],
      4: [],
      5: [],
      6: [],
      7: [],
      8: [],
      9: []
  }
  system_z = np.random.randint(10, size=(2 * N_atoms))

  features.append(
      (frag1_coords, frag1_nbr_list, frag1_z, frag2_coords, frag2_nbr_list,
       frag2_z, system_coords, system_nbr_list, system_z))
  features = np.asarray(features)
  labels = np.random.rand(batch_size)
  dataset = NumpyDataset(features, labels)

  acm.fit(dataset, nb_epoch=1)

  reloaded_model = AtomicConvModel(
      n_tasks=1,
      batch_size=batch_size,
      layer_sizes=[
          1,
      ],
      frag1_num_atoms=5,
      frag2_num_atoms=5,
      complex_num_atoms=10,
      model_dir=model_dir)
  reloaded_model.restore()

  # Check predictions match on random sample
  origpred = acm.predict(dataset)
  reloadpred = reloaded_model.predict(dataset)
  assert np.all(origpred == reloadpred)


def test_normalizing_flow_model_reload():
  """Test that NormalizingFlowModel can be reloaded correctly."""
  from deepchem.models.normalizing_flows import NormalizingFlow, NormalizingFlowModel