Commit 7b173d98 authored by Nathan Frey
Browse files

Consolidate loop

parent ec4516bd
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -549,7 +549,7 @@ class JsonLoader(DataLoader):
        ids = ids[valid_inds]

        if len(self.tasks) > 0:
          # Featurize task results iff they exist.
          # Featurize task results if they exist.
          y, w = _convert_df_to_numpy(shard, self.tasks)

          if self.label_field:
@@ -611,16 +611,16 @@ class JsonLoader(DataLoader):
    """

    features = []
    valid_inds = []
    field = self.feature_field
    data = shard[field].tolist()
    for idx, datapoint in enumerate(data):
      features.append(featurizer.featurize([datapoint]))

    valid_inds = np.array(
        [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
    features = [
        elt for (is_valid, elt) in zip(valid_inds, features) if is_valid
    ]
    for idx, datapoint in enumerate(data):
      feat = featurizer.featurize([datapoint])
      is_valid = True if feat.size > 0 else False
      valid_inds.append(is_valid)
      if is_valid:
        features.append(feat)

    return np.squeeze(np.array(features), axis=1), valid_inds

+2 −1
Original line number Diff line number Diff line
@@ -22,7 +22,8 @@ class TestJsonLoader(unittest.TestCase):
    self.current_dir = os.path.dirname(os.path.abspath(__file__))

  def test_json_loader(self):
    input_file = os.path.join(self.current_dir, 'perov_test.json')
    input_file = os.path.join(self.current_dir,
                              'inorganic_crystal_sample_data.json')
    featurizer = SineCoulombMatrix(max_atoms=5)
    loader = JsonLoader(
        tasks=['e_form'],