Commit c821268a authored by Nathan Frey's avatar Nathan Frey
Browse files

Add tests

parent 5acb2308
Loading
Loading
Loading
Loading
+66 −11
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ import time
import sys
import logging
import warnings
from typing import List, Optional
from typing import List, Optional, Dict

from deepchem.utils.save import load_csv_files, load_json_files
from deepchem.utils.save import load_sdf_files
@@ -449,12 +449,26 @@ class JsonLoader(DataLoader):
  large json files that you don't want to manipulate directly in
  memory.

  It is meant to load JSON files formatted as "records" in line
  delimited format, which allows for sharding.
  ``list like [{column -> value}, ... , {column -> value}]``.

  Examples
  --------
  >> import pandas as pd
  >> df = pd.DataFrame(some_data)
  >> df.columns.tolist()
  .. ['formula', structure', 'task']
  >> df.to_json('file.json', orient='records', lines=True)
  >> loader = JsonLoader(['task'], {'structure': dict}, 'formula')
  >> dataset = loader.create_dataset('file.json')
  
  """

  def __init__(self,
               tasks: List[str],
               smiles_field: Optional[str] = None,
               id_field: Optional[str] = None,
               json_fields: Dict[str, type],
               id_field: str = None,
               featurizer: Optional[Featurizer] = None,
               log_every_n: int = 1000):
    """Initializes JsonLoader.
@@ -463,10 +477,11 @@ class JsonLoader(DataLoader):
    ----------
    tasks : List[str]
      List of task names
    smiles_field : str, optional
      Name of field that holds smiles string 
    id_field : str, optional
      Name of field that holds sample identifier
    json_fields : Dict[str, type]
      column names and dtypes in dataframe containing data to be featurized
      e.g. {"structure": dict, "composition": str}
    id_field : str, default None
      Column for identifying samples.
    featurizer : dc.feat.Featurizer, optional
      Featurizer to use to process data
    log_every_n : int, optional
@@ -477,9 +492,9 @@ class JsonLoader(DataLoader):
    if not isinstance(tasks, list):
      raise ValueError("Tasks must be a list.")
    self.tasks = tasks
    self.smiles_field = smiles_field
    self.json_fields = json_fields
    if id_field is None:
      self.id_field = smiles_field
      self.id_field = next(iter(json_fields))
    else:
      self.id_field = id_field

@@ -495,12 +510,52 @@ class JsonLoader(DataLoader):

  def _featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    return _featurize_smiles_df(
    return self._featurize_df(
        shard,
        self.featurizer,
        field=self.smiles_field,
        json_fields=self.json_fields,
        log_every_n=self.log_every_n)

  def _featurize_df(self,
                    shard,
                    featurizer: Featurizer,
                    json_fields: Dict[str, type],
                    log_every_n: int = 1000):
    """Featurize individual materials in dataframe.

    Helper that given a featurizer that operates on individual
    inorganic crystal structures, computes & adds features for
    that compound to the features dataframe.

    Parameters
    ----------
    shard: pd.DataFrame
      DataFrame that holds pymatgen.Structure dict or 
      pymatgen.Composition str
    featurizer: CrystalFeaturizer
      A crystal featurizer object
    json_fields : Dict[str, type]
      column names and dtypes in dataframe containing data to be featurized
      e.g. {"structure": dict, "composition": str}
    log_every_n: int, optional (default 1000)
      Emit a logging statement every `log_every_n` rows.

    """

    features = []
    field = next(iter(json_fields))
    data = shard[field].tolist()
    for idx, datapoint in enumerate(data):
      features.append(featurizer.featurize([datapoint]))

    valid_inds = np.array(
        [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
    features = [
        elt for (is_valid, elt) in zip(valid_inds, features) if is_valid
    ]

    return np.squeeze(np.array(features), axis=1), valid_inds


class SDFLoader(DataLoader):
  """
+5 −0
Original line number Diff line number Diff line
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[3.9545311068,0.0,0.0],[0.0,3.9545311068,0.0],[0.0,0.0,3.9545311068]],"a":3.9545311068,"b":3.9545311068,"c":3.9545311068,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":61.8422081649},"sites":[{"species":[{"element":"Rh","occu":1}],"abc":[0.0,0.0,0.0],"xyz":[0.0,0.0,0.0],"label":"Rh","properties":{}},{"species":[{"element":"Te","occu":1}],"abc":[0.5,0.5,0.5],"xyz":[1.9772655534,1.9772655534,1.9772655534],"label":"Te","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.5,0.0,0.5],"xyz":[1.9772655534,0.0,1.9772655534],"label":"N","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.5,0.5,0.0],"xyz":[1.9772655534,1.9772655534,0.0],"label":"N","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.0,0.5,0.5],"xyz":[0.0,1.9772655534,1.9772655534],"label":"N","properties":{}}]},"e_form":2.16,"formula":"TeRhN3"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2894318978,0.0,0.0],[0.0,4.2894318978,0.0],[0.0,0.0,4.2894318978]],"a":4.2894318978,"b":4.2894318978,"c":4.2894318978,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":78.9222269246},"sites":[{"species":[{"element":"Hf","occu":1}],"abc":[0.5922504528,0.0,0.0],"xyz":[2.5404179838,0.0,0.0],"label":"Hf","properties":{}},{"species":[{"element":"Te","occu":1}],"abc":[0.2378848852,0.5,0.5],"xyz":[1.0203910146,2.1447159489,2.1447159489],"label":"Te","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.5012320713,0.0,0.5],"xyz":[2.1500008347,0.0,2.1447159489],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.5012320713,0.5,0.0],"xyz":[2.1500008347,2.1447159489,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.7980811547,0.5,0.5],"xyz":[3.4233147622,2.1447159489,2.1447159489],"label":"O","properties":{}}]},"e_form":1.52,"formula":"HfTeO3"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2926387638,0.0,0.0],[0.0,4.2926387638,0.0],[0.0,0.0,4.2926387638]],"a":4.2926387638,"b":4.2926387638,"c":4.2926387638,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":79.0993708544},"sites":[{"species":[{"element":"Re","occu":1}],"abc":[0.1416166515,0.0,0.0],"xyz":[0.6079091278,0.0,0.0],"label":"Re","properties":{}},{"species":[{"element":"As","occu":1}],"abc":[0.5093856748,0.5,0.5],"xyz":[2.1866086932,2.1463193819,2.1463193819],"label":"As","properties":{}},{"species":[{"element":"F","occu":1}],"abc":[0.5316865005,0.0,0.5],"xyz":[2.2823380822,0.0,2.1463193819],"label":"F","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.3074869463,0.5,0.0],"xyz":[1.319930385,2.1463193819,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.927582418,0.5,0.5],"xyz":[3.9817762444,2.1463193819,2.1463193819],"label":"O","properties":{}}]},"e_form":1.48,"formula":"ReAsO2F"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.1837305646,0.0,0.0],[0.0,4.1837305646,0.0],[0.0,0.0,4.1837305646]],"a":4.1837305646,"b":4.1837305646,"c":4.1837305646,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":73.2303523231},"sites":[{"species":[{"element":"W","occu":1}],"abc":[0.676648156,0.0,0.0],"xyz":[2.8309135716,0.0,0.0],"label":"W","properties":{}},{"species":[{"element":"Re","occu":1}],"abc":[0.6351628832,0.5,0.5],"xyz":[2.6573503678,2.0918652823,2.0918652823],"label":"Re","properties":{}},{"species":[{"element":"S","occu":1}],"abc":[0.3728524724,0.0,0.5],"xyz":[1.5599142849,0.0,2.0918652823],"label":"S","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.7238489421,0.5,0.0],"xyz":[3.0283889434,2.0918652823,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.0978520248,0.5,0.5],"xyz":[0.4093865068,2.0918652823,2.0918652823],"label":"O","properties":{}}]},"e_form":1.24,"formula":"ReWSO2"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2811442539,0.0,0.0],[0.0,4.2811442539,0.0],[0.0,0.0,4.2811442539]],"a":4.2811442539,"b":4.2811442539,"c":4.2811442539,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":78.4656515166},"sites":[{"species":[{"element":"Bi","occu":1}],"abc":[0.0012121467,0.0,0.0],"xyz":[0.0051893747,0.0,0.0],"label":"Bi","properties":{}},{"species":[{"element":"Hf","occu":1}],"abc":[0.5074940801,0.5,0.5],"xyz":[2.1726553651,2.140572127,2.140572127],"label":"Hf","properties":{}},{"species":[{"element":"F","occu":1}],"abc":[0.4990106707,0.0,0.5],"xyz":[2.1363366656,0.0,2.140572127],"label":"F","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.499996373,0.5,0.0],"xyz":[2.1405565992,2.140572127,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.002611863,0.5,0.5],"xyz":[0.0111817624,2.140572127,2.140572127],"label":"O","properties":{}}]},"e_form":0.62,"formula":"HfBiO2F"}
 No newline at end of file
+37 −0
Original line number Diff line number Diff line
"""
Tests for JsonLoader class.
"""

import os
import unittest
import tempfile
import shutil
import numpy as np
import deepchem as dc
from deepchem.data.data_loader import JsonLoader
from deepchem.feat.materials_featurizers import SineCoulombMatrix


class TestJsonLoader(unittest.TestCase):
  """
  Test JsonLoader
  """

  def setUp(self):
    super(TestJsonLoader, self).setUp()
    self.current_dir = os.path.dirname(os.path.abspath(__file__))

  def test_json_loader(self):
    input_file = os.path.join(self.current_dir, 'perov_test.json')
    featurizer = SineCoulombMatrix(max_atoms=5)
    loader = JsonLoader(
        tasks=['e_form'],
        json_fields={"structure": dict},
        id_field='formula',
        featurizer=featurizer)
    dataset = loader.create_dataset(input_file, shard_size=1)

    a = [4625.32086965, 6585.20209678, 61.00680193, 48.72230922, 48.72230922]

    assert dataset.X.shape == (5, 1, 5)
    assert np.allclose(dataset.X[0][0], a, atol=.5)
+1 −1
Original line number Diff line number Diff line
@@ -151,7 +151,7 @@ class SineCoulombMatrix(Featurizer):

    if self.flatten:
      eigs, _ = np.linalg.eig(sine_mat)
      zeros = np.zeros((self.max_atoms,))
      zeros = np.zeros((1,self.max_atoms))
      zeros[:len(eigs)] = eigs
      features = zeros
    else:
+5 −7
Original line number Diff line number Diff line
@@ -10,9 +10,11 @@ import numpy as np
import os
import deepchem
import warnings
import logging
from typing import List, Optional
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode

logger = logging.getLogger(__name__)

def log(string, verbose=True):
  """Print string if verbose."""
@@ -118,8 +120,7 @@ def load_csv_files(filenames, shard_size=None, verbose=True):


def load_json_files(filenames: List[str],
                    shard_size: Optional[int] = None,
                    verbose: bool = True):
                    shard_size: Optional[int] = None):
  """Load data as pandas dataframe.

  Parameters
@@ -128,8 +129,6 @@ def load_json_files(filenames: List[str],
    List of json filenames.
  shard_size : int, optional
    Chunksize for reading json files.
  verbose : bool (default True)
    Log json loading with shard numbers.

  Yields
  ------
@@ -149,11 +148,10 @@ def load_json_files(filenames: List[str],
    if shard_size is None:
      yield pd.read_json(filename)
    else:
      log("About to start loading json from %s" % filename, verbose)
      logger.info("About to start loading json from %s." % filename)
      for df in pd.read_json(
          filename, orient='records', chunksize=shard_size, lines=True):
        log("Loading shard %d of size %s." % (shard_num, str(shard_size)),
            verbose)
        logger.info("Loading shard %d of size %s." % (shard_num, str(shard_size)))
        df = df.replace(np.nan, str(""), regex=True)
        shard_num += 1
        yield df